# classify_images/Hugging.py: fine-tune ViT on CIFAR-10 with the Hugging Face Trainer
import numpy as np
import tensorflow as tf
import torch
from PIL import Image

from datasets import load_dataset, load_metric
from transformers import (
    DefaultDataCollator,
    TFViTForImageClassification,
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
    create_optimizer,
)
def convert_to_tf_tensor(image: Image.Image):
    """Turn a PIL image into a batched float tensor of shape (1, 224, 224, 3)."""
    np_image = np.array(image)
    tf_image = tf.convert_to_tensor(np_image, dtype=tf.float32)
    if tf_image.shape.rank == 2:
        tf_image = tf_image[..., tf.newaxis]  # grayscale: add a channel axis before resizing
    tf_image = tf.image.resize(tf_image, [224, 224])  # resize to 224x224
    if tf_image.shape[-1] == 1:
        tf_image = tf.repeat(tf_image, 3, axis=-1)  # repeat the channel axis to get 3 channels
    return tf.expand_dims(tf_image, 0)  # add a batch dimension
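
# Hedged sanity check for convert_to_tf_tensor (illustrative only; the helper is
# not used in the training flow below). Image.new is just a stand-in input:
#
#   demo = Image.new("L", (32, 32))  # hypothetical 32x32 grayscale image
#   assert convert_to_tf_tensor(demo).shape == (1, 224, 224, 3)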
def preprocess(batch):
    # take a list of PIL images and turn them into pixel values
    inputs = feature_extractor(
        batch['img'],  # CIFAR-10 stores its images under the 'img' key
        return_tensors='pt'
    )
    # carry the labels through alongside the pixel values
    inputs['label'] = batch['label']
    return inputs
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }
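
# Hedged shape check for collate_fn (illustrative; the Trainer calls it for us):
#
#   fake = [{'pixel_values': torch.zeros(3, 224, 224), 'label': 0}] * 4  # hypothetical batch
#   out = collate_fn(fake)
#   assert out['pixel_values'].shape == (4, 3, 224, 224) and out['labels'].shape == (4,)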
def compute_metrics(p):
    # accuracy over the argmax of the logits vs. the true labels
    return metric.compute(
        predictions=np.argmax(p.predictions, axis=1),
        references=p.label_ids
    )
if __name__ == '__main__':
    dataset_train = load_dataset(
        'cifar10',
        split='train[:1000]',  # first 1,000 training examples
        ignore_verifications=False  # set to True if you hit a split-verification error
    )
    print(dataset_train)
    dataset_test = load_dataset(
        'cifar10',
        split='test',  # evaluation dataset
        ignore_verifications=True  # set to True if you hit a split-verification error
    )
    print(dataset_test)
    # check how many labels/number of classes
    num_classes = len(set(dataset_train['label']))
    labels = dataset_train.features['label']
    print(num_classes, labels)
    print(dataset_train[0]['label'], labels.names[dataset_train[0]['label']])
    # load the pretrained feature extractor
    model_id = 'google/vit-base-patch16-224-in21k'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_id)
    print(feature_extractor)
    # run one image through it to inspect the output shape
    example = feature_extractor(
        dataset_train[0]['img'],
        return_tensors='pt'
    )
    print(example)
    print(example['pixel_values'].shape)
    # apply the preprocessing lazily to both splits
    prepared_train = dataset_train.with_transform(preprocess)
    prepared_test = dataset_test.with_transform(preprocess)
    # accuracy metric
    metric = load_metric("accuracy")
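
    # Hedged example of what load_metric("accuracy") computes (illustrative values):
    #
    #   metric.compute(predictions=[1, 0, 1], references=[1, 1, 1])
    #   # -> {'accuracy': 0.6666666666666666}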
    training_args = TrainingArguments(
        output_dir="./cifar",
        per_device_train_batch_size=16,
        evaluation_strategy="steps",
        num_train_epochs=4,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,  # keep 'img' so with_transform still sees it
        push_to_hub=True,
        load_best_model_at_end=True,
    )
    labels = dataset_train.features['label'].names
    model = ViTForImageClassification.from_pretrained(
        model_id,
        num_labels=len(labels)  # classification head sized to the label set
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_train,
        eval_dataset=prepared_test,
        tokenizer=feature_extractor,  # saved alongside the model checkpoints
    )
    # run the training
    train_results = trainer.train()
    trainer.push_to_hub()
    # save the feature extractor with the model
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    # save the trainer state
    trainer.save_state()
    # ---- alternative TF/Keras fine-tuning path ----
    batch_size = 16
    num_epochs = 5
    # total optimizer steps = batches per epoch * number of epochs
    num_train_steps = (len(dataset_train) // batch_size) * num_epochs
    learning_rate = 3e-5
    weight_decay_rate = 0.01
    optimizer, lr_schedule = create_optimizer(
        init_lr=learning_rate,
        num_train_steps=num_train_steps,
        weight_decay_rate=weight_decay_rate,
        num_warmup_steps=0,
    )
    # to_tf_dataset needs a collator that returns TF tensors, not the torch
    # collate_fn used by the Trainer above
    tf_collator = DefaultDataCollator(return_tensors="tf")
    tf_train_dataset = prepared_train.to_tf_dataset(
        columns=["pixel_values"],
        label_cols=["label"],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=tf_collator
    )
    tf_eval_dataset = prepared_test.to_tf_dataset(
        columns=["pixel_values"],
        label_cols=["label"],
        batch_size=batch_size,
        shuffle=False,
        collate_fn=tf_collator
    )
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # compile() needs the TF version of the model; the torch model above has no Keras API
    tf_model = TFViTForImageClassification.from_pretrained(model_id, num_labels=len(labels))
    tf_model.compile(optimizer=optimizer, loss=loss)
    # evaluate the fine-tuned model (eval_dataset is prepared_test, set above)
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    print(metrics)
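
    # Hedged usage sketch (assumes the run above finished and ./cifar holds the
    # saved checkpoint): classify one held-out image with the fine-tuned model.
    #
    #   clf = ViTForImageClassification.from_pretrained("./cifar")
    #   inputs = feature_extractor(dataset_test[0]["img"], return_tensors="pt")
    #   with torch.no_grad():
    #       pred = clf(**inputs).logits.argmax(-1).item()
    #   print("predicted:", labels[pred], "actual:", labels[dataset_test[0]["label"]])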