File size: 5,264 Bytes
6d16138 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import numpy as np
import tensorflow as tf
from PIL import Image
import torch
from datasets import load_metric
from datasets import load_dataset
from transformers import (ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer, create_optimizer)
def convert_to_tf_tensor(image: Image.Image, size=(224, 224)):
    """Convert a PIL image into a batched 3-channel TensorFlow tensor.

    The image is first converted to RGB with PIL, so grayscale, palette and
    RGBA inputs all end up with exactly three channels. Previously the code
    blindly repeated the last axis 3x, which turned RGB input into 9
    channels and failed on 2-D grayscale arrays.

    Args:
        image: PIL image in any mode that ``Image.convert("RGB")`` accepts.
        size: target ``(height, width)`` passed to ``tf.image.resize``;
            defaults to 224x224.

    Returns:
        A tensor of shape ``(1, height, width, 3)``. ``tf.image.resize``
        (default bilinear method) returns float32 values.
    """
    np_image = np.array(image.convert("RGB"))
    tf_image = tf.convert_to_tensor(np_image)
    tf_image = tf.image.resize(tf_image, list(size))
    # Add a leading batch axis so the result can be fed to a model directly.
    return tf.expand_dims(tf_image, 0)
def preprocess(batch):
    """Turn a batch of raw dataset rows into model inputs.

    Runs the module-level ``feature_extractor`` over the list of images in
    ``batch['img']`` (asking for PyTorch tensors). It then attaches
    ``batch['label']`` unchanged under the ``'label'`` key.
    """
    encoded = feature_extractor(batch['img'], return_tensors='pt')
    # Keep the targets next to the pixel values so the collator can find them.
    encoded['label'] = batch['label']
    return encoded
def collate_fn(batch):
    """Assemble a list of preprocessed examples into one training batch.

    Each example must provide a ``'pixel_values'`` tensor and an integer
    ``'label'``. Pixel tensors are stacked along a new leading axis. Labels
    are collected into a tensor under the ``'labels'`` key.
    """
    pixel_stack = torch.stack([example['pixel_values'] for example in batch])
    label_tensor = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_stack, 'labels': label_tensor}
def compute_metrics(p):
    """Score an evaluation pass with the module-level ``metric``.

    ``p.predictions`` holds per-class scores; the argmax over axis 1 picks
    the predicted class index. These are compared against ``p.label_ids``.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)
if __name__ == '__main__':
    # Fine-tune a ViT image classifier on CIFAR-10 with the PyTorch Trainer.
    # ViTForImageClassification is a PyTorch model, so Keras-style
    # compile()/to_tf_dataset() training does not apply here.
    dataset_train = load_dataset(
        'cifar10',
        split='train[:1000]',  # training dataset (first 1,000 examples)
        ignore_verifications=False  # set to True if seeing splits Error
    )
    print(dataset_train)
    dataset_test = load_dataset(
        'cifar10',
        split='test',  # test dataset, used for evaluation
        ignore_verifications=True  # set to True if seeing splits Error
    )
    print(dataset_test)

    # check how many labels appear in the training slice / number of classes
    num_classes = len(set(dataset_train['label']))
    label_feature = dataset_train.features['label']
    print(num_classes, label_feature)
    first_label = dataset_train[0]['label']
    print(first_label, label_feature.names[first_label])

    # import model
    model_id = 'google/vit-base-patch16-224-in21k'
    # `feature_extractor` and `metric` must stay module-level names:
    # preprocess() and compute_metrics() read them as globals.
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    print(feature_extractor)
    example = feature_extractor(
        dataset_train[0]['img'],
        return_tensors='pt'
    )
    print(example)
    print(example['pixel_values'].shape)

    # transform the datasets on the fly with preprocess()
    prepared_train = dataset_train.with_transform(preprocess)
    prepared_test = dataset_test.with_transform(preprocess)

    # accuracy metric
    metric = load_metric("accuracy")

    training_args = TrainingArguments(
        output_dir="./cifar",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        # keep the raw 'img' column so the on-the-fly transform can read it
        remove_unused_columns=False,
        # requires a Hugging Face login; pass hub_model_id=... to pick the repo
        push_to_hub=True,
        load_best_model_at_end=True,
    )

    labels = label_feature.names
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=len(labels)  # classification head sized to the label names
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,  # saved together with the model
    )

    # Run the training
    train_results = trainer.train()
    # save the model (and feature extractor) with the training metrics
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    # save the trainer state
    trainer.save_state()

    # Evaluate the model once on the test split
    metrics = trainer.evaluate(prepared_test)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    print(metrics)

    # Push last so the Hub repo also receives the saved metrics and state.
    trainer.push_to_hub()
|