In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
# imports

import pandas as pd
import os
from pathlib import Path
from PIL import Image
import shutil
from logging import root
from PIL import Image
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import (
 Seq2SeqTrainer,
 Seq2SeqTrainingArguments,
 get_linear_schedule_with_warmup,
 AutoFeatureExtractor,
 AutoTokenizer,
 ViTFeatureExtractor,
 VisionEncoderDecoderModel,
 default_data_collator,
)
from transformers.optimization import AdamW

from box import Box
import inspect


In [None]:
# custom functions

class ImageCaptionDataset(Dataset):
    """Map-style dataset pairing images with tokenized caption labels.

    Args:
        df: DataFrame with a ``filename`` column (path relative to ``images_dir``)
            and a ``text`` column (the caption). Rows are accessed by position,
            so any index (e.g. one left over from filtering) works.
        feature_extractor: callable turning a PIL image into an object with a
            ``pixel_values`` tensor when called with ``return_tensors="pt"``.
        tokenizer: callable that encodes ``text`` into ``input_ids`` (and,
            when available, ``attention_mask``); must expose ``pad_token_id``.
        images_dir: directory holding the image files (``str`` or ``Path``).
        max_target_length: labels are padded/truncated to this many tokens.
    """

    def __init__(
        self, df, feature_extractor, tokenizer, images_dir, max_target_length=128
    ):
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        # Accept plain strings as well; the `/` join in __getitem__ needs a Path.
        self.images_dir = Path(images_dir)
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": Tensor, "labels": Tensor}`` for row position ``idx``.

        ``labels`` holds ``max_target_length`` token ids; padding positions are
        replaced by -100 so the loss function ignores them.
        """
        # Positional lookup: label-based `df[col][idx]` fails on non-default indexes.
        filename = self.df["filename"].iloc[idx]
        text = self.df["text"].iloc[idx]

        # Prepare the image (resize + normalize). The `with` closes the file handle;
        # convert() returns an in-memory copy that stays valid afterwards.
        with Image.open(self.images_dir / filename) as raw_image:
            image = raw_image.convert("RGB")
        # Drop only the leading batch dimension added by return_tensors="pt".
        pixel_values = self.feature_extractor(
            image, return_tensors="pt"
        ).pixel_values.squeeze(0)

        # Encode the caption into label ids.
        # NOTE(review): if the tokenizer does not append an EOS token, the model never
        # sees an end-of-caption target; consider appending tokenizer.eos_token to `text`.
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        )
        input_ids = encoded["input_ids"]

        # Mask padding by position rather than by token id: when pad_token is set to
        # eos_token, comparing ids would also hide genuine EOS tokens from the loss.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            attention_mask = [
                int(token != self.tokenizer.pad_token_id) for token in input_ids
            ]
        labels = [
            token if keep else -100 for token, keep in zip(input_ids, attention_mask)
        ]

        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels),
        }



def predict(
    image,
    max_length=64,
    num_beams=4,
    *,
    model=None,
    feature_extractor=None,
    tokenizer=None,
    device=None,
):
    """Generate caption(s) for ``image`` with beam search.

    Args:
        image: a PIL image (or whatever ``feature_extractor`` accepts as ``images``).
        max_length: maximum generated sequence length, passed to ``model.generate``.
        num_beams: beam width, passed to ``model.generate``.
        model, feature_extractor, tokenizer, device: optional overrides; each one
            left as ``None`` falls back to the notebook-level global of the same name.

    Returns:
        A list of decoded captions (special tokens skipped, whitespace stripped),
        one per generated sequence.
    """
    # Fall back to the notebook globals, preserving the original call style.
    notebook_globals = globals()
    if model is None:
        model = notebook_globals["model"]
    if feature_extractor is None:
        feature_extractor = notebook_globals["feature_extractor"]
    if tokenizer is None:
        tokenizer = notebook_globals["tokenizer"]
    if device is None:
        device = notebook_globals["device"]

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Disable dropout for inference, then restore the caller's mode so calling
    # predict() mid-training does not silently switch the model to eval.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output_ids = model.generate(
                pixel_values,
                max_length=max_length,
                num_beams=num_beams,
                return_dict_in_generate=True,
            ).sequences
    finally:
        if was_training:
            model.train()

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


In [None]:
# Absolute path of the dataset root, and the GPU when one is visible (else CPU).
data_dir = Path("datasets").resolve()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Arguments describing the data we feed the model for training and evaluation.
data_training_args = Box(
    # Maximum total length of the tokenized target text; longer sequences are
    # truncated, shorter ones padded.
    max_target_length=64,
    # Beam width used when generating; copied onto the model config and passed to predict().
    num_beams=4,
    # Folder containing all the images.
    images_dir=data_dir / "images",
)

In [None]:
# Arguments choosing which pretrained encoder/decoder checkpoints to fine-tune,
# plus generation settings that are copied onto the model config.
model_args = Box(
    # Vision encoder: path to a pretrained model or a model identifier from huggingface.co/models.
    encoder_model_name_or_path="google/vit-base-patch16-224-in21k",
    # Text decoder: path to a pretrained model or a model identifier from huggingface.co/models.
    decoder_model_name_or_path="gpt2",
    # If > 0, every n-gram of this size may occur at most once in the generated text.
    no_repeat_ngram_size=3,
    # Exponential length penalty used by default in the model's generate method.
    length_penalty=2.0,
)

In [None]:
# Arguments for the Trainer class. Refer: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments

training_args = {
    "num_train_epochs": 5,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "output_dir": "output_dir",
    "do_train": True,
    "do_eval": True,
    # fp16 mixed precision is only supported on CUDA devices (transformers rejects it
    # otherwise), so enable it only when a GPU is present; the notebook falls back to CPU.
    "fp16": torch.cuda.is_available(),
    "learning_rate": 1e-5,
    "load_best_model_at_end": True,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "report_to": "none",
}

seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)

In [None]:
# Encoder-side image preprocessing and decoder-side tokenizer.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    model_args.encoder_model_name_or_path
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.decoder_model_name_or_path, use_fast=True
)
# Reuse EOS as the padding token (presumably the decoder tokenizer defines no pad token).
tokenizer.pad_token = tokenizer.eos_token

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path
)

# Resize the decoder embeddings BEFORE copying the vocab size into the top-level
# config, so model.config.vocab_size reflects any resize.
model.decoder.resize_token_embeddings(len(tokenizer))

# Special tokens used for creating the decoder_input_ids from the labels.
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Make sure the vocab size is set correctly.
model.config.vocab_size = model.decoder.config.vocab_size

# Beam-search / generation defaults.
# Stop on the decoder tokenizer's EOS token: `tokenizer.sep_token_id` is None for
# causal-LM tokenizers without a SEP token (e.g. GPT-2), which disables early stopping.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = data_training_args.max_target_length
model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size
model.config.length_penalty = model_args.length_penalty
model.config.num_beams = data_training_args.num_beams


In [None]:
# Load the caption tables for both splits.
train_df = pd.read_csv(data_dir / "train.csv")
valid_df = pd.read_csv(data_dir / "valid.csv")

# Both splits share every constructor argument except the dataframe they read from.
train_dataset, eval_dataset = (
    ImageCaptionDataset(
        df=split_df,
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        images_dir=data_training_args.images_dir,
        max_target_length=data_training_args.max_target_length,
    )
    for split_df in (train_df, valid_df)
)

print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(eval_dataset)}")

In [None]:
# Sanity-check one training example: print the tensor shape of each field.
encoding = train_dataset[0]
for field_name, tensor in encoding.items():
    print(field_name, tensor.shape)

In [None]:
# We can also check the original image and decode the labels.
# The `with` closes the file handle; convert() returns an in-memory copy to display.
with Image.open(
    data_training_args.images_dir / train_df["filename"].iloc[0]
) as raw_image:
    image = raw_image.convert("RGB")
image

In [None]:
# Decode the labels back to text. masked_fill returns a new tensor, so the
# -100 -> pad replacement does not mutate encoding["labels"] in place.
labels = encoding["labels"].masked_fill(
    encoding["labels"] == -100, tokenizer.pad_token_id
)
label_str = tokenizer.decode(labels, skip_special_tokens=True)
print(label_str)


In [None]:
import math

# torch's AdamW (transformers' own AdamW is deprecated), with every Adam
# hyper-parameter taken from the training arguments so configured values apply.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=seq2seq_training_args.learning_rate,
    betas=(seq2seq_training_args.adam_beta1, seq2seq_training_args.adam_beta2),
    eps=seq2seq_training_args.adam_epsilon,
    weight_decay=seq2seq_training_args.weight_decay,
)

# Count optimizer updates the way the training loop performs them:
#  - the global batch spans every device/process (train_batch_size * world_size),
#  - a trailing partial batch is still a step unless dataloader_drop_last is set,
#  - gradient accumulation divides the number of optimizer updates.
# Undercounting here makes the linear schedule reach lr=0 before training ends.
global_batch_size = (
    seq2seq_training_args.train_batch_size * seq2seq_training_args.world_size
)
if seq2seq_training_args.dataloader_drop_last:
    batches_per_epoch = len(train_dataset) // global_batch_size
else:
    batches_per_epoch = math.ceil(len(train_dataset) / global_batch_size)
steps_per_epoch = max(
    batches_per_epoch // seq2seq_training_args.gradient_accumulation_steps, 1
)

# An explicit max_steps overrides the epoch-based count.
if seq2seq_training_args.max_steps > 0:
    num_training_steps = seq2seq_training_args.max_steps
else:
    num_training_steps = math.ceil(
        steps_per_epoch * seq2seq_training_args.num_train_epochs
    )

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=seq2seq_training_args.warmup_steps,
    num_training_steps=num_training_steps,
)

optimizers = (optimizer, lr_scheduler)

In [None]:
# Wire model, arguments, data and the custom optimizer/scheduler pair into the trainer.
# feature_extractor is passed as `tokenizer` (presumably so the Trainer saves it with
# checkpoints — confirm against the installed transformers version).
trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_training_args,
    data_collator=default_data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=feature_extractor,
    optimizers=optimizers,
)

trainer.train()

In [None]:
# Caption an example image from outside the training data.
test_img = "../examples/tt7991608-red-notice.jpg"
with Image.open(test_img) as raw_test_image:
    # Match the training pipeline, which converts every image to RGB. convert()
    # returns an in-memory copy, so `image` stays usable after the file is closed.
    image = raw_test_image.convert("RGB")

preds = predict(
    image,
    max_length=data_training_args.max_target_length,
    num_beams=data_training_args.num_beams,
)

# Uncomment to display the test image in a jupyter notebook
# display(image)
print(preds[0])