| | import numpy as np |
| | import tensorflow as tf |
| | from PIL import Image |
| | import torch |
| | from datasets import load_metric |
| | from datasets import load_dataset |
| | from transformers import (ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer, create_optimizer) |
| |
|
def convert_to_tf_tensor(image: Image):
    """Convert a PIL image to a batched, 3-channel 224x224 TF tensor.

    Args:
        image: a PIL image (grayscale or RGB).

    Returns:
        A float tensor of shape (1, 224, 224, 3).
    """
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image)
    # BUG FIX: a grayscale PIL image produces a rank-2 (H, W) array, which
    # `tf.image.resize` rejects — add the channel axis before resizing.
    if tf_image.shape.rank == 2:
        tf_image = tf_image[..., tf.newaxis]
    tf_image = tf.image.resize(tf_image, [224, 224])
    # BUG FIX: only tile single-channel images up to 3 channels; the
    # original repeated unconditionally, turning RGB input into 9 channels.
    if tf_image.shape[-1] == 1:
        tf_image = tf.repeat(tf_image, 3, -1)
    # Add a leading batch dimension.
    return tf.expand_dims(tf_image, 0)
| |
|
def preprocess(batch):
    """Encode a batch of images with the module-level ViT feature
    extractor, carrying the labels through unchanged.

    Expects `batch` to have 'img' and 'label' keys; returns the
    extractor's encoding dict (torch tensors) with 'label' added.
    """
    encoded = feature_extractor(batch['img'], return_tensors='pt')
    encoded['label'] = batch['label']
    return encoded
| |
|
def collate_fn(batch):
    """Collate a list of per-example dicts into batched torch tensors.

    Each element of `batch` must carry a 'pixel_values' tensor and an
    integer 'label'; the result uses the Trainer's expected keys
    ('pixel_values', 'labels').
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}
| |
|
def compute_metrics(p):
    """Turn an EvalPrediction into an accuracy dict via the module-level
    `metric` (argmax over the logits' class axis vs. the gold labels)."""
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(
        predictions=predicted_classes,
        references=p.label_ids,
    )
| |
|
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Data: a 1000-example CIFAR-10 training slice plus the full test
    # split.  NOTE(review): `ignore_verifications` is deprecated in newer
    # `datasets` releases (replaced by `verification_mode`) — kept as-is
    # for compatibility with the pinned version.
    # ------------------------------------------------------------------
    dataset_train = load_dataset(
        'cifar10',
        split='train[:1000]',
        ignore_verifications=False
    )
    print(dataset_train)

    dataset_test = load_dataset(
        'cifar10',
        split='test',
        ignore_verifications=True
    )
    print(dataset_test)

    # Inspect the label space.
    num_classes = len(set(dataset_train['label']))
    labels = dataset_train.features['label']
    print(num_classes, labels)
    print(dataset_train[0]['label'], labels.names[dataset_train[0]['label']])

    # ------------------------------------------------------------------
    # Feature extractor: resizes/normalizes images for the ViT encoder.
    # ------------------------------------------------------------------
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    print(feature_extractor)

    # Sanity-check the extractor on a single example.
    example = feature_extractor(dataset_train[0]['img'], return_tensors='pt')
    print(example)
    print(example['pixel_values'].shape)

    # Apply preprocessing lazily, at example-access time.
    prepared_train = dataset_train.with_transform(preprocess)
    prepared_test = dataset_test.with_transform(preprocess)

    # Accuracy metric consumed by compute_metrics().
    # NOTE(review): `load_metric` is deprecated in newer `datasets`
    # releases in favor of the `evaluate` package — confirm the pinned
    # version before upgrading.
    metric = load_metric("accuracy")

    # NOTE(review): push_to_hub=True requires a Hugging Face auth token;
    # load_best_model_at_end needs matching save/eval cadence (both are
    # every 100 steps here).
    training_args = TrainingArguments(
        output_dir="./cifar",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,  # keep 'img' for the lazy transform
        push_to_hub=True,
        load_best_model_at_end=True,
    )

    labels = dataset_train.features['label'].names

    # Fresh classification head sized to the CIFAR-10 label set.
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=len(labels)
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,  # saved alongside the model checkpoint
    )

    # Train, publish, and persist model + metrics + trainer state.
    train_results = trainer.train()
    trainer.push_to_hub()
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    # BUG FIX(review): the original script continued with a TF/Keras tail
    # that could never run:
    #   * `len(dataset_train["train"])` raises — `dataset_train` is already
    #     a single split, and string indexing selects a *column*;
    #   * `model.compile(optimizer=..., loss=...)` was called on the
    #     PyTorch `ViTForImageClassification`, which is not a Keras model;
    #   * the `to_tf_dataset(...)` pipelines used the torch-tensor
    #     `collate_fn` and were never consumed afterwards.
    # That dead/broken TF-interop section has been removed; evaluation
    # runs through the (PyTorch) Trainer instead.  The original also
    # evaluated the same split twice (`evaluate(prepared_test)` followed
    # by `evaluate()` — identical, since eval_dataset IS prepared_test),
    # so only a single evaluation pass is kept.
    eval_results = trainer.evaluate()
    trainer.log_metrics("eval", eval_results)
    trainer.save_metrics("eval", eval_results)
    print(eval_results)
| |
|