import numpy as np
import tensorflow as tf
import torch
from PIL import Image
from datasets import load_dataset, load_metric
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    TFViTForImageClassification,
    TrainingArguments,
    Trainer,
    create_optimizer,
)


def convert_to_tf_tensor(image: Image.Image):
    # convert a PIL image to a batched tf.Tensor of shape (1, 224, 224, 3)
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    tf_image = tf.image.resize(tf_image, [224, 224])  # resize to 224x224
    tf_image = tf.repeat(tf_image, 3, -1)  # repeat along the colour dimension to simulate 3 channels
    return tf.expand_dims(tf_image, 0)


def preprocess(batch):
    # take a list of PIL images and turn them into pixel values
    inputs = feature_extractor(batch['img'], return_tensors='pt')
    # include the labels
    inputs['label'] = batch['label']
    return inputs


def collate_fn(batch):
    # stack individual examples into a single batch of tensors
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }


def compute_metrics(p):
    # accuracy over the argmax of the logits
    return metric.compute(
        predictions=np.argmax(p.predictions, axis=1),
        references=p.label_ids
    )


if __name__ == '__main__':
    dataset_train = load_dataset(
        'cifar10',
        split='train[:1000]',        # training dataset
        ignore_verifications=False   # set to True if you see a splits error
    )
    print(dataset_train)

    dataset_test = load_dataset(
        'cifar10',
        split='test',                # test dataset
        ignore_verifications=True    # set to True if you see a splits error
    )
    print(dataset_test)

    # check how many labels / number of classes
    num_classes = len(set(dataset_train['label']))
    labels = dataset_train.features['label']
    print(num_classes, labels)
    print(dataset_train[0]['label'], labels.names[dataset_train[0]['label']])

    # load the feature extractor
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    print(feature_extractor)

    example = feature_extractor(dataset_train[0]['img'], return_tensors='pt')
    print(example)
    print(example['pixel_values'].shape)

    # transform the training and test datasets on the fly
    prepared_train = dataset_train.with_transform(preprocess)
    prepared_test = dataset_test.with_transform(preprocess)

    # accuracy metric
    metric = load_metric("accuracy")

    training_args = TrainingArguments(
        output_dir="./cifar",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,
        push_to_hub=True,
        # push_to_hub_model_id="classify_images",
        load_best_model_at_end=True,
    )

    labels = dataset_train.features['label'].names
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=len(labels)  # size of the classification head
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,  # saved alongside the model
    )

    # run the training
    train_results = trainer.train()
    trainer.push_to_hub()

    # save the feature extractor with the model
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    # save the trainer state
    trainer.save_state()

    # evaluate the fine-tuned model
    metrics = trainer.evaluate(prepared_test)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    print(metrics)
    # --- alternative: train the same architecture with Keras instead of the Trainer API ---
    batch_size = 16
    num_epochs = 5
    # number of optimisation steps = batches per epoch * epochs (used by the LR schedule)
    num_train_steps = (len(dataset_train) // batch_size) * num_epochs
    learning_rate = 3e-5
    weight_decay_rate = 0.01
    optimizer, lr_schedule = create_optimizer(
        init_lr=learning_rate,
        num_train_steps=num_train_steps,
        weight_decay_rate=weight_decay_rate,
        num_warmup_steps=0,
    )

    # wrap the prepared datasets as tf.data.Dataset objects;
    # column names must match the keys produced by collate_fn
    tf_train_dataset = prepared_train.to_tf_dataset(
        columns=["pixel_values"],
        label_cols=["labels"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    tf_eval_dataset = prepared_test.to_tf_dataset(
        columns=["pixel_values"],
        label_cols=["labels"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # compile() needs the TensorFlow variant of the model; the ViTForImageClassification
    # instance trained above is a PyTorch module and has no compile() method
    tf_model = TFViTForImageClassification.from_pretrained(model_id, num_labels=len(labels))
    tf_model.compile(optimizer=optimizer, loss=loss)
    # training the Keras model would then be, e.g.:
    # tf_model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs)