|
import numpy as np |
|
import tensorflow as tf |
|
from PIL import Image |
|
import torch |
|
from datasets import load_metric |
|
from datasets import load_dataset |
|
from transformers import (ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer, create_optimizer) |
|
|
|
def convert_to_tf_tensor(image: Image.Image):
    """Convert a PIL image into a batched, 224x224, three-channel tensor.

    Args:
        image: The PIL image to convert. Grayscale images (which NumPy
            exposes as 2-D ``(height, width)`` arrays) are supported.

    Returns:
        A TensorFlow tensor with a leading batch axis of size 1 and a
        spatial size of 224x224. Single-channel input is replicated to
        three channels; input that already has several channels keeps
        its channel count.
    """
    np_image = np.array(image)
    if np_image.ndim == 2:
        # tf.image.resize only accepts 3-D (H, W, C) or 4-D input, so a
        # grayscale (H, W) array needs an explicit channel axis first.
        np_image = np_image[..., np.newaxis]

    tf_image = tf.convert_to_tensor(np_image)
    tf_image = tf.image.resize(tf_image, [224, 224])

    if tf_image.shape[-1] == 1:
        # Replicate the lone channel three times. Repeating unconditionally
        # would turn an RGB image into a 9-channel tensor.
        tf_image = tf.repeat(tf_image, 3, -1)

    return tf.expand_dims(tf_image, 0)
|
|
|
def preprocess(batch):
    """Run the module-level ``feature_extractor`` over a batch of images.

    Args:
        batch: Mapping with an ``'img'`` entry (the images handed to the
            feature extractor) and a ``'label'`` entry. Presumably the
            column dict that ``datasets`` passes to a ``with_transform``
            callback -- confirm against the caller.

    Returns:
        The feature extractor's output (requested as PyTorch tensors) with
        the batch's ``'label'`` entry copied onto it unchanged.
    """
    encoded = feature_extractor(batch['img'], return_tensors='pt')
    # Carry the targets along so the collator can find them later.
    encoded['label'] = batch['label']
    return encoded
|
|
|
def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Args:
        batch: Sequence of dicts, each holding a ``'pixel_values'`` tensor
            and a ``'label'`` value.

    Returns:
        A dict with ``'pixel_values'`` (the example tensors stacked along a
        new leading axis) and ``'labels'`` (a tensor built from the
        per-example labels, in the same order).
    """
    stacked_pixels = torch.stack([sample['pixel_values'] for sample in batch])
    label_tensor = torch.tensor([sample['label'] for sample in batch])
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}
|
|
|
def compute_metrics(p):
    """Score predictions with the module-level ``metric``.

    Args:
        p: Object exposing ``predictions`` (per-class scores, reduced with
            an argmax over axis 1) and ``label_ids`` (the reference labels).
            Presumably the evaluation output that ``Trainer`` passes to its
            ``compute_metrics`` callback -- confirm against the caller.

    Returns:
        Whatever ``metric.compute`` returns for the predicted class indices
        and the reference labels.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)
|
|
|
if __name__ == '__main__':
    # Fine-tune a pre-trained ViT checkpoint on the first 1,000 CIFAR-10
    # training images with `Trainer`, then evaluate on the test split.

    # --- Data -------------------------------------------------------------
    dataset_train = load_dataset(
        'cifar10',
        split='train[:1000]',
        ignore_verifications=False
    )
    print(dataset_train)

    dataset_test = load_dataset(
        'cifar10',
        split='test',
        ignore_verifications=True
    )
    print(dataset_test)

    num_classes = len(set(dataset_train['label']))
    labels = dataset_train.features['label']
    print(num_classes, labels)

    # Show the first example's numeric label next to its class name.
    print(dataset_train[0]['label'], labels.names[dataset_train[0]['label']])

    # --- Feature extractor ------------------------------------------------
    model_id = 'google/vit-base-patch16-224-in21k'
    # NOTE: `feature_extractor` must stay a module-level name because
    # preprocess() reads it as a global.
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    print(feature_extractor)

    # Sanity check: run a single image through the extractor.
    example = feature_extractor(
        dataset_train[0]['img'],
        return_tensors='pt'
    )
    print(example)
    print(example['pixel_values'].shape)

    prepared_train = dataset_train.with_transform(preprocess)
    prepared_test = dataset_test.with_transform(preprocess)

    # NOTE: `metric` must stay a module-level name because
    # compute_metrics() reads it as a global.
    metric = load_metric("accuracy")

    # --- Training setup ---------------------------------------------------
    training_args = TrainingArguments(
        output_dir="./cifar",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        # Keep the 'img'/'label' columns: preprocess() needs them when the
        # transform runs.
        remove_unused_columns=False,
        push_to_hub=True,
        load_best_model_at_end=True,
    )

    labels = dataset_train.features['label'].names

    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=len(labels)
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,
    )

    # --- Train and persist ------------------------------------------------
    train_results = trainer.train()
    trainer.push_to_hub()

    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()

    # --- Evaluate ---------------------------------------------------------
    # No Keras optimizer / tf.data / model.compile() setup here: the model is
    # trained and evaluated through `Trainer` only. `dataset_train` is a
    # single split, so it is never indexed by split name.
    #
    # `prepared_test` is also the trainer's `eval_dataset`, so one evaluation
    # pass provides both the logged/saved metrics and the printed results.
    eval_results = trainer.evaluate(prepared_test)
    trainer.log_metrics("eval", eval_results)
    trainer.save_metrics("eval", eval_results)

    print(eval_results)
|
|