---
base_model: openai/clip-vit-base-patch32
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: outputs
  results: []
license: apache-2.0
datasets:
- Andron00e/CIFAR10-custom
language:
- en
library_name: transformers
---

# outputs

This model is a fine-tuned version of [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) on the [CIFAR10](https://huggingface.co/datasets/Andron00e/CIFAR10-custom) dataset.
It achieves the following results on the evaluation set:
- Loss: 0.8115
- Accuracy: 0.8255

## Model description

CLIP ViT-B/32 fine-tuned for 10-class image classification through the custom `CLIPForImageClassification` class (see the usage example below).

## Training and evaluation data

Trained on [Andron00e/CIFAR10-custom](https://huggingface.co/datasets/Andron00e/CIFAR10-custom); the demo below splits it 80/10/10 into train/validation/test.

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training (a standalone optimizer/scheduler sketch follows the list):
- learning_rate: 0.0002
- train_batch_size: 10
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 4
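For reference, this optimizer and schedule correspond roughly to the following manual setup (a minimal sketch with a hypothetical stand-in module; the actual training used the `Trainer` API shown in the usage example below):

```python
import torch
from transformers import get_linear_schedule_with_warmup

model = torch.nn.Linear(512, 10)  # hypothetical stand-in for the fine-tuned model

# Adam with betas=(0.9, 0.999) and epsilon=1e-08, as listed above
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), eps=1e-8)

# Linear decay over the full run: 4 epochs = 19,200 steps per the results table;
# no warmup is listed, so num_warmup_steps=0 (the Trainer default)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=19200
)

# Inside the training loop, each optimizer step is followed by a scheduler step:
# optimizer.step(); scheduler.step(); optimizer.zero_grad()
```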
### Training results

| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|:-------------:|:-----:|:-----:|:---------------:|:--------:|
| 1.7258 | 0.02 | 100 | 1.6999 | 0.8048 |
| 1.669 | 0.04 | 200 | 1.6798 | 0.8055 |
| 1.6704 | 0.06 | 300 | 1.6599 | 0.8053 |
| 1.6655 | 0.08 | 400 | 1.6407 | 0.8047 |
| 1.5754 | 0.1 | 500 | 1.6223 | 0.809 |
| 1.6159 | 0.12 | 600 | 1.6040 | 0.8068 |
| 1.5663 | 0.15 | 700 | 1.5858 | 0.8073 |
| 1.5426 | 0.17 | 800 | 1.5677 | 0.8095 |
| 1.5794 | 0.19 | 900 | 1.5506 | 0.808 |
| 1.5504 | 0.21 | 1000 | 1.5342 | 0.8035 |
| 1.554 | 0.23 | 1100 | 1.5179 | 0.802 |
| 1.4831 | 0.25 | 1200 | 1.5022 | 0.7972 |
| 1.4718 | 0.27 | 1300 | 1.4867 | 0.7955 |
| 1.5206 | 0.29 | 1400 | 1.4716 | 0.796 |
| 1.4534 | 0.31 | 1500 | 1.4567 | 0.7963 |
| 1.3932 | 0.33 | 1600 | 1.4427 | 0.7875 |
| 1.4635 | 0.35 | 1700 | 1.4289 | 0.789 |
| 1.4339 | 0.38 | 1800 | 1.4151 | 0.793 |
| 1.4492 | 0.4 | 1900 | 1.4016 | 0.7973 |
| 1.4369 | 0.42 | 2000 | 1.3881 | 0.8018 |
| 1.4007 | 0.44 | 2100 | 1.3754 | 0.801 |
| 1.3697 | 0.46 | 2200 | 1.3627 | 0.8025 |
| 1.3298 | 0.48 | 2300 | 1.3505 | 0.8048 |
| 1.2809 | 0.5 | 2400 | 1.3386 | 0.8068 |
| 1.2989 | 0.52 | 2500 | 1.3272 | 0.8067 |
| 1.2958 | 0.54 | 2600 | 1.3159 | 0.81 |
| 1.3072 | 0.56 | 2700 | 1.3048 | 0.8097 |
| 1.2545 | 0.58 | 2800 | 1.2943 | 0.809 |
| 1.2722 | 0.6 | 2900 | 1.2834 | 0.8112 |
| 1.2628 | 0.62 | 3000 | 1.2732 | 0.8102 |
| 1.2357 | 0.65 | 3100 | 1.2632 | 0.8105 |
| 1.3189 | 0.67 | 3200 | 1.2532 | 0.8093 |
| 1.2465 | 0.69 | 3300 | 1.2436 | 0.8097 |
| 1.2579 | 0.71 | 3400 | 1.2342 | 0.8087 |
| 1.1963 | 0.73 | 3500 | 1.2249 | 0.8085 |
| 1.1701 | 0.75 | 3600 | 1.2159 | 0.8092 |
| 1.2117 | 0.77 | 3700 | 1.2069 | 0.8113 |
| 1.1907 | 0.79 | 3800 | 1.1984 | 0.8112 |
| 1.1903 | 0.81 | 3900 | 1.1902 | 0.8115 |
| 1.2357 | 0.83 | 4000 | 1.1821 | 0.8115 |
| 1.1924 | 0.85 | 4100 | 1.1738 | 0.8117 |
| 1.1914 | 0.88 | 4200 | 1.1657 | 0.8133 |
| 1.1536 | 0.9 | 4300 | 1.1580 | 0.8148 |
| 1.1893 | 0.92 | 4400 | 1.1505 | 0.8158 |
| 1.1811 | 0.94 | 4500 | 1.1433 | 0.8158 |
| 1.0182 | 0.96 | 4600 | 1.1358 | 0.8165 |
| 1.0396 | 0.98 | 4700 | 1.1287 | 0.8158 |
| 1.1502 | 1.0 | 4800 | 1.1217 | 0.816 |
| 1.1764 | 1.02 | 4900 | 1.1147 | 0.8158 |
| 1.1508 | 1.04 | 5000 | 1.1080 | 0.8152 |
| 1.0518 | 1.06 | 5100 | 1.1015 | 0.8155 |
| 1.0648 | 1.08 | 5200 | 1.0952 | 0.816 |
| 1.1631 | 1.1 | 5300 | 1.0889 | 0.8153 |
| 1.0629 | 1.12 | 5400 | 1.0826 | 0.8152 |
| 1.1151 | 1.15 | 5500 | 1.0771 | 0.815 |
| 1.1377 | 1.17 | 5600 | 1.0711 | 0.8145 |
| 1.0353 | 1.19 | 5700 | 1.0652 | 0.8158 |
| 1.068 | 1.21 | 5800 | 1.0594 | 0.815 |
| 1.0834 | 1.23 | 5900 | 1.0538 | 0.8162 |
| 1.0002 | 1.25 | 6000 | 1.0483 | 0.8165 |
| 1.0024 | 1.27 | 6100 | 1.0428 | 0.817 |
| 1.0609 | 1.29 | 6200 | 1.0376 | 0.817 |
| 1.0901 | 1.31 | 6300 | 1.0324 | 0.816 |
| 1.0772 | 1.33 | 6400 | 1.0275 | 0.8173 |
| 0.9434 | 1.35 | 6500 | 1.0226 | 0.817 |
| 0.9692 | 1.38 | 6600 | 1.0178 | 0.8157 |
| 1.0461 | 1.4 | 6700 | 1.0131 | 0.8155 |
| 1.0583 | 1.42 | 6800 | 1.0086 | 0.8143 |
| 0.9369 | 1.44 | 6900 | 1.0042 | 0.8157 |
| 1.0685 | 1.46 | 7000 | 0.9998 | 0.8152 |
| 1.062 | 1.48 | 7100 | 0.9955 | 0.8153 |
| 1.0394 | 1.5 | 7200 | 0.9912 | 0.8142 |
| 1.031 | 1.52 | 7300 | 0.9870 | 0.8157 |
| 0.9556 | 1.54 | 7400 | 0.9829 | 0.8155 |
| 0.9846 | 1.56 | 7500 | 0.9789 | 0.8152 |
| 0.9995 | 1.58 | 7600 | 0.9750 | 0.8158 |
| 1.0273 | 1.6 | 7700 | 0.9711 | 0.8163 |
| 0.9383 | 1.62 | 7800 | 0.9674 | 0.817 |
| 0.951 | 1.65 | 7900 | 0.9634 | 0.8163 |
| 0.9457 | 1.67 | 8000 | 0.9598 | 0.8167 |
| 1.012 | 1.69 | 8100 | 0.9563 | 0.816 |
| 0.9683 | 1.71 | 8200 | 0.9529 | 0.8158 |
| 0.9582 | 1.73 | 8300 | 0.9495 | 0.8157 |
| 0.9005 | 1.75 | 8400 | 0.9461 | 0.8162 |
| 0.888 | 1.77 | 8500 | 0.9428 | 0.8175 |
| 0.9267 | 1.79 | 8600 | 0.9396 | 0.8168 |
| 0.9298 | 1.81 | 8700 | 0.9364 | 0.8168 |
| 1.0072 | 1.83 | 8800 | 0.9334 | 0.8167 |
| 0.9425 | 1.85 | 8900 | 0.9303 | 0.8158 |
| 0.9729 | 1.88 | 9000 | 0.9273 | 0.8168 |
| 0.9104 | 1.9 | 9100 | 0.9244 | 0.8175 |
| 0.9153 | 1.92 | 9200 | 0.9216 | 0.817 |
| 0.9115 | 1.94 | 9300 | 0.9188 | 0.8165 |
| 0.9079 | 1.96 | 9400 | 0.9161 | 0.8168 |
| 0.8453 | 1.98 | 9500 | 0.9133 | 0.8175 |
| 0.8323 | 2.0 | 9600 | 0.9107 | 0.817 |
| 0.9071 | 2.02 | 9700 | 0.9080 | 0.8183 |
| 0.9331 | 2.04 | 9800 | 0.9054 | 0.8185 |
| 0.886 | 2.06 | 9900 | 0.9029 | 0.8193 |
| 0.8562 | 2.08 | 10000 | 0.9006 | 0.8183 |
| 0.8904 | 2.1 | 10100 | 0.8980 | 0.8193 |
| 0.8247 | 2.12 | 10200 | 0.8956 | 0.8188 |
| 0.8114 | 2.15 | 10300 | 0.8934 | 0.8202 |
| 0.96 | 2.17 | 10400 | 0.8912 | 0.8198 |
| 0.9326 | 2.19 | 10500 | 0.8889 | 0.8198 |
| 0.8057 | 2.21 | 10600 | 0.8867 | 0.8195 |
| 0.8266 | 2.23 | 10700 | 0.8846 | 0.8188 |
| 0.7909 | 2.25 | 10800 | 0.8823 | 0.82 |
| 0.886 | 2.27 | 10900 | 0.8803 | 0.8192 |
| 0.8691 | 2.29 | 11000 | 0.8783 | 0.8193 |
| 0.8676 | 2.31 | 11100 | 0.8763 | 0.8187 |
| 0.8147 | 2.33 | 11200 | 0.8744 | 0.819 |
| 0.7723 | 2.35 | 11300 | 0.8725 | 0.8195 |
| 0.9222 | 2.38 | 11400 | 0.8705 | 0.8188 |
| 0.9692 | 2.4 | 11500 | 0.8687 | 0.8195 |
| 0.8792 | 2.42 | 11600 | 0.8669 | 0.8188 |
| 0.939 | 2.44 | 11700 | 0.8650 | 0.8193 |
| 0.9093 | 2.46 | 11800 | 0.8633 | 0.8188 |
| 0.7794 | 2.48 | 11900 | 0.8616 | 0.8182 |
| 0.8572 | 2.5 | 12000 | 0.8599 | 0.8182 |
| 0.9035 | 2.52 | 12100 | 0.8582 | 0.8185 |
| 0.8063 | 2.54 | 12200 | 0.8566 | 0.8193 |
| 0.8935 | 2.56 | 12300 | 0.8550 | 0.8195 |
| 0.7991 | 2.58 | 12400 | 0.8535 | 0.8192 |
| 0.856 | 2.6 | 12500 | 0.8520 | 0.8195 |
| 0.8374 | 2.62 | 12600 | 0.8505 | 0.8197 |
| 0.8418 | 2.65 | 12700 | 0.8490 | 0.8203 |
| 0.9232 | 2.67 | 12800 | 0.8475 | 0.8208 |
| 0.8335 | 2.69 | 12900 | 0.8462 | 0.8207 |
| 0.8659 | 2.71 | 13000 | 0.8449 | 0.8205 |
| 0.9798 | 2.73 | 13100 | 0.8435 | 0.8205 |
| 0.7288 | 2.75 | 13200 | 0.8423 | 0.8205 |
| 0.9086 | 2.77 | 13300 | 0.8411 | 0.821 |
| 0.7912 | 2.79 | 13400 | 0.8398 | 0.8205 |
| 0.8675 | 2.81 | 13500 | 0.8386 | 0.8202 |
| 0.8045 | 2.83 | 13600 | 0.8374 | 0.8198 |
| 0.8421 | 2.85 | 13700 | 0.8362 | 0.8202 |
| 0.7453 | 2.88 | 13800 | 0.8350 | 0.8202 |
| 0.7348 | 2.9 | 13900 | 0.8339 | 0.8203 |
| 0.8977 | 2.92 | 14000 | 0.8328 | 0.8205 |
| 0.859 | 2.94 | 14100 | 0.8318 | 0.821 |
| 0.8571 | 2.96 | 14200 | 0.8307 | 0.8212 |
| 0.8158 | 2.98 | 14300 | 0.8297 | 0.8215 |
| 0.8635 | 3.0 | 14400 | 0.8287 | 0.8215 |
| 0.9095 | 3.02 | 14500 | 0.8277 | 0.8215 |
| 0.8491 | 3.04 | 14600 | 0.8268 | 0.8217 |
| 0.9136 | 3.06 | 14700 | 0.8259 | 0.8223 |
| 0.8652 | 3.08 | 14800 | 0.8250 | 0.8218 |
| 0.9299 | 3.1 | 14900 | 0.8242 | 0.8215 |
| 0.8259 | 3.12 | 15000 | 0.8233 | 0.8215 |
| 0.775 | 3.15 | 15100 | 0.8225 | 0.8222 |
| 0.801 | 3.17 | 15200 | 0.8217 | 0.8217 |
| 0.8535 | 3.19 | 15300 | 0.8209 | 0.8215 |
| 0.7973 | 3.21 | 15400 | 0.8202 | 0.8217 |
| 0.8937 | 3.23 | 15500 | 0.8195 | 0.8213 |
| 0.7632 | 3.25 | 15600 | 0.8188 | 0.821 |
| 0.8117 | 3.27 | 15700 | 0.8181 | 0.8212 |
| 0.8941 | 3.29 | 15800 | 0.8174 | 0.8217 |
| 0.802 | 3.31 | 15900 | 0.8168 | 0.8225 |
| 0.8303 | 3.33 | 16000 | 0.8161 | 0.8217 |
| 0.8264 | 3.35 | 16100 | 0.8155 | 0.8218 |
| 0.8411 | 3.38 | 16200 | 0.8149 | 0.8213 |
| 0.9378 | 3.4 | 16300 | 0.8143 | 0.8218 |
| 0.8514 | 3.42 | 16400 | 0.8138 | 0.8217 |
| 0.7313 | 3.44 | 16500 | 0.8133 | 0.8222 |
| 0.8238 | 3.46 | 16600 | 0.8128 | 0.8218 |
| 0.7876 | 3.48 | 16700 | 0.8123 | 0.8222 |
| 0.8364 | 3.5 | 16800 | 0.8118 | 0.8222 |
| 0.7049 | 3.52 | 16900 | 0.8114 | 0.8222 |
| 0.9101 | 3.54 | 17000 | 0.8109 | 0.8218 |
| 0.7984 | 3.56 | 17100 | 0.8105 | 0.822 |
| 0.85 | 3.58 | 17200 | 0.8101 | 0.8218 |
| 0.8677 | 3.6 | 17300 | 0.8098 | 0.822 |
| 0.8797 | 3.62 | 17400 | 0.8094 | 0.8218 |
| 0.7847 | 3.65 | 17500 | 0.8091 | 0.8222 |
| 0.8415 | 3.67 | 17600 | 0.8088 | 0.8218 |
| 0.8702 | 3.69 | 17700 | 0.8085 | 0.8222 |
| 0.8979 | 3.71 | 17800 | 0.8082 | 0.8222 |
| 0.8387 | 3.73 | 17900 | 0.8080 | 0.8222 |
| 0.8467 | 3.75 | 18000 | 0.8077 | 0.822 |
| 0.8729 | 3.77 | 18100 | 0.8075 | 0.822 |
| 0.8291 | 3.79 | 18200 | 0.8073 | 0.8222 |
| 0.7897 | 3.81 | 18300 | 0.8072 | 0.8222 |
| 0.8039 | 3.83 | 18400 | 0.8070 | 0.822 |
| 0.771 | 3.85 | 18500 | 0.8069 | 0.8223 |
| 0.7704 | 3.88 | 18600 | 0.8067 | 0.8223 |
| 0.7695 | 3.9 | 18700 | 0.8066 | 0.8223 |
| 0.8958 | 3.92 | 18800 | 0.8066 | 0.8223 |
| 0.8342 | 3.94 | 18900 | 0.8065 | 0.8223 |
| 0.8725 | 3.96 | 19000 | 0.8064 | 0.8225 |
| 0.8657 | 3.98 | 19100 | 0.8064 | 0.8225 |
| 0.779 | 4.0 | 19200 | 0.8064 | 0.8225 |

### Framework versions

- Transformers 4.35.2
- Pytorch 2.1.0+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0
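To recreate this exact environment before running the demo, the versions above can be pinned (a hedged suggestion; the demo below instead installs the latest releases with `-U`):

```python
# Optional: pin the framework versions listed above instead of installing
# the latest releases. The cu118 index matches the "2.1.0+cu118" build.
!pip install transformers==4.35.2 datasets==2.15.0 tokenizers==0.15.0
!pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
```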
### Example of usage

A simple end-to-end demo for Google Colab (the `!` and `%` lines are Colab shell and magic commands):

```python
!pip install datasets transformers[torch] accelerate -U
!git clone https://github.com/Andron00e/CLIPForImageClassification
%cd CLIPForImageClassification/clip_for_classification

import numpy as np
import torch
from datasets import DatasetDict, load_dataset, load_metric
from transformers import CLIPProcessor, Trainer, TrainingArguments
from modeling_clipforimageclassification import CLIPForImageClassification

processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
model = CLIPForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1", 10)

# 80/10/10 train/validation/test split of the custom CIFAR10 dataset
dataset = load_dataset("Andron00e/CIFAR10-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})

classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
           5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

def transform(example_batch):
    # Pair each image with the text of its class label, as CLIP expects
    inputs = processor(text=[classes[x] for x in example_batch["labels"]],
                       images=[x for x in example_batch["image"]],
                       padding=True, return_tensors="pt")
    inputs["labels"] = example_batch["labels"]
    return inputs

def collate_fn(batch):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }

# load_metric still works in Datasets 2.15.0, though it is deprecated there
metric = load_metric("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1),
                          references=p.label_ids)

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=False,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep the raw "image" column for the transform
    push_to_hub=False,
    report_to="tensorboard",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=dataset.with_transform(transform)["train"],
    eval_dataset=dataset.with_transform(transform)["validation"],
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Final evaluation on the held-out test split
metrics = trainer.evaluate(dataset.with_transform(transform)["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

%cd ..
%cd ..
```
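After training, per-example predictions on the held-out test split can be inspected with `Trainer.predict` (a short follow-up sketch reusing the `trainer`, `dataset`, `transform`, and `classes` objects defined above; it assumes, like `compute_metrics`, that the model's logits arrive in `.predictions`):

```python
import numpy as np

# PredictionOutput carries logits (.predictions), labels (.label_ids), and metrics
test_outputs = trainer.predict(dataset.with_transform(transform)["test"])
predicted_ids = np.argmax(test_outputs.predictions, axis=1)

print([classes[int(i)] for i in predicted_ids[:10]])  # first ten predicted class names
print(test_outputs.metrics)                           # accuracy via compute_metrics
```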