{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "0fbed7bc", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:14.921553Z", "start_time": "2021-12-09T15:34:14.911112Z" } }, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "id": "c4b60ef3", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:15.961098Z", "start_time": "2021-12-09T15:34:14.922771Z" }, "code_folding": [] }, "outputs": [], "source": [ "# imports\n", "\n", "import pandas as pd\n", "import os\n", "from pathlib import Path\n", "from PIL import Image\n", "import shutil\n", "from logging import root\n", "from PIL import Image\n", "from pathlib import Path\n", "import pandas as pd\n", "import torch\n", "from torch.utils.data import Dataset\n", "from PIL import Image\n", "from transformers import (\n", " Seq2SeqTrainer,\n", " Seq2SeqTrainingArguments,\n", " get_linear_schedule_with_warmup,\n", " AutoFeatureExtractor,\n", " AutoTokenizer,\n", " ViTFeatureExtractor,\n", " VisionEncoderDecoderModel,\n", " default_data_collator,\n", ")\n", "from transformers.optimization import AdamW\n", "\n", "from box import Box\n", "import inspect\n" ] }, { "cell_type": "code", "execution_count": null, "id": "99d6f14d", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:15.979191Z", "start_time": "2021-12-09T15:34:15.962078Z" }, "code_folding": [] }, "outputs": [], "source": [ "# custom functions\n", "\n", "class ImageCaptionDataset(Dataset):\n", " def __init__(\n", " self, df, feature_extractor, tokenizer, images_dir, max_target_length=128\n", " ):\n", " self.df = df\n", " self.feature_extractor = feature_extractor\n", " self.tokenizer = tokenizer\n", " self.images_dir = images_dir\n", " self.max_target_length = max_target_length\n", "\n", " def __len__(self):\n", " return len(self.df)\n", "\n", " def __getitem__(self, idx):\n", " filename = self.df[\"filename\"][idx]\n", " text = self.df[\"text\"][idx]\n", " # prepare image (i.e. 
resize + normalize)\n", " image = Image.open(self.images_dir / filename).convert(\"RGB\")\n", " pixel_values = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n", " # add labels (input_ids) by encoding the text\n", " labels = self.tokenizer(\n", " text,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=self.max_target_length,\n", " ).input_ids\n", " # important: make sure that PAD tokens are ignored by the loss function\n", " labels = [\n", " label if label != self.tokenizer.pad_token_id else -100 for label in labels\n", " ]\n", "\n", " encoding = {\n", " \"pixel_values\": pixel_values.squeeze(),\n", " \"labels\": torch.tensor(labels),\n", " }\n", " return encoding\n", "\n", "\n", "\n", "def predict(image, max_length=64, num_beams=4):\n", "\n", " pixel_values = feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n", " pixel_values = pixel_values.to(device)\n", "\n", " with torch.no_grad():\n", " output_ids = model.generate(\n", " pixel_values,\n", " max_length=max_length,\n", " num_beams=num_beams,\n", " return_dict_in_generate=True,\n", " ).sequences\n", "\n", " preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n", " preds = [pred.strip() for pred in preds]\n", "\n", " return preds\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ea66826b", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:16.042990Z", "start_time": "2021-12-09T15:34:15.980557Z" } }, "outputs": [], "source": [ "data_dir = Path(\"datasets\").resolve()\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "17cfb2c2", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:16.058421Z", "start_time": "2021-12-09T15:34:16.044111Z" } }, "outputs": [], "source": [ "# arguments pertaining to what data we are going to input our model for training and eval.\n", "\n", "data_training_args = {\n", " # The maximum total sequence length for target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.\n", " \"max_target_length\": 64,\n", "\n", " # Number of beams to use for evaluation. 
This argument will be passed to model.generate, which is used during evaluation and prediction.\n", " \"num_beams\": 4,\n", "\n", " # Folder with all the images\n", " \"images_dir\": data_dir / \"images\",\n", "}\n", "\n", "data_training_args = Box(data_training_args)" ] }, { "cell_type": "code", "execution_count": null, "id": "adc4839a", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:16.073242Z", "start_time": "2021-12-09T15:34:16.059354Z" } }, "outputs": [], "source": [ "# arguments pertaining to which model/config/tokenizer we are going to fine-tune from.\n", "\n", "model_args = {\n", "\n", " # Path to the pretrained encoder model or a model identifier from huggingface.co/models\n", " \"encoder_model_name_or_path\": \"google/vit-base-patch16-224-in21k\",\n", "\n", " # Path to the pretrained decoder model or a model identifier from huggingface.co/models\n", " \"decoder_model_name_or_path\": \"gpt2\",\n", "\n", " # If set to an int > 0, all n-grams of that size can only occur once.\n", " \"no_repeat_ngram_size\": 3,\n", "\n", " # Exponential length penalty used by default in the model's generate method.\n", " \"length_penalty\": 2.0,\n", "}\n", "\n", "model_args = Box(model_args)" ] }, { "cell_type": "code", "execution_count": null, "id": "22b8c9e3", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:16.089201Z", "start_time": "2021-12-09T15:34:16.074223Z" } }, "outputs": [], "source": [ "# arguments pertaining to the Trainer class. Refer: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments\n", "\n", "training_args = {\n", " \"num_train_epochs\": 5,\n", " \"per_device_train_batch_size\": 32,\n", " \"per_device_eval_batch_size\": 32,\n", " \"output_dir\": \"output_dir\",\n", " \"do_train\": True,\n", " \"do_eval\": True,\n", " \"fp16\": True,\n", " \"learning_rate\": 1e-5,\n", " \"load_best_model_at_end\": True,\n", " \"evaluation_strategy\": \"epoch\",\n", " \"save_strategy\": \"epoch\",\n", " \"report_to\": \"none\"\n", "}\n", "\n", "seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)" ] }, { "cell_type": "code", "execution_count": null, "id": "d0023eac", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:37.844396Z", "start_time": "2021-12-09T15:34:16.090085Z" } }, "outputs": [], "source": [ "feature_extractor = ViTFeatureExtractor.from_pretrained(\n", " model_args.encoder_model_name_or_path\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(\n", " model_args.decoder_model_name_or_path, use_fast=True\n", ")\n", "# GPT-2 has no padding token, so reuse its end-of-sequence token for padding\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(\n", " model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path\n", ")\n", "\n", "# set special tokens used for creating the decoder_input_ids from the labels\n", "model.config.decoder_start_token_id = tokenizer.bos_token_id\n", "model.config.pad_token_id = tokenizer.pad_token_id\n", "# make sure vocab size is set correctly\n", "model.config.vocab_size = model.config.decoder.vocab_size\n", "\n", "# set generation parameters (GPT-2 has no sep token, so generation ends at its eos token)\n", "model.config.eos_token_id = tokenizer.eos_token_id\n", "model.config.max_length = data_training_args.max_target_length\n", "model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size\n", "model.config.length_penalty = model_args.length_penalty\n", "model.config.num_beams = data_training_args.num_beams\n", "model.decoder.resize_token_embeddings(len(tokenizer))\n" ] }, { "cell_type": "code", "execution_count": null, "id": 
"6428ea08", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:37.933804Z", "start_time": "2021-12-09T15:34:37.845607Z" } }, "outputs": [], "source": [ "train_df = pd.read_csv(data_dir / \"train.csv\")\n", "valid_df = pd.read_csv(data_dir / \"valid.csv\")\n", "\n", "train_dataset = ImageCaptionDataset(\n", " df=train_df,\n", " feature_extractor=feature_extractor,\n", " tokenizer=tokenizer,\n", " images_dir=data_training_args.images_dir,\n", " max_target_length=data_training_args.max_target_length,\n", ")\n", "eval_dataset = ImageCaptionDataset(\n", " df=valid_df,\n", " feature_extractor=feature_extractor,\n", " tokenizer=tokenizer,\n", " images_dir=data_training_args.images_dir,\n", " max_target_length=data_training_args.max_target_length,\n", ")\n", "\n", "print(f\"Number of training examples: {len(train_dataset)}\")\n", "print(f\"Number of validation examples: {len(eval_dataset)}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "c8e492a1", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:37.971630Z", "start_time": "2021-12-09T15:34:37.935339Z" } }, "outputs": [], "source": [ "# Let's verify an example from the training dataset:\n", "\n", "encoding = train_dataset[0]\n", "for k,v in encoding.items():\n", " print(k, v.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "edb4e7a6", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:38.006980Z", "start_time": "2021-12-09T15:34:37.972483Z" } }, "outputs": [], "source": [ "# We can also check the original image and decode the labels:\n", "image = Image.open(data_training_args.images_dir / train_df[\"filename\"][0]).convert(\"RGB\")\n", "image" ] }, { "cell_type": "code", "execution_count": null, "id": "25f2cae7", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:38.031745Z", "start_time": "2021-12-09T15:34:38.008027Z" } }, "outputs": [], "source": [ "labels = encoding[\"labels\"]\n", "labels[labels == -100] = tokenizer.pad_token_id\n", "label_str = tokenizer.decode(labels, skip_special_tokens=True)\n", "print(label_str)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b7a009d3", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T15:34:38.049539Z", "start_time": "2021-12-09T15:34:38.032749Z" } }, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=seq2seq_training_args.learning_rate)\n", "\n", "steps_per_epoch = len(train_dataset) // seq2seq_training_args.per_device_train_batch_size\n", "num_training_steps = steps_per_epoch * seq2seq_training_args.num_train_epochs\n", "\n", "lr_scheduler = get_linear_schedule_with_warmup(\n", " optimizer,\n", " num_warmup_steps=seq2seq_training_args.warmup_steps,\n", " num_training_steps=num_training_steps,\n", ")\n", "\n", "optimizers = (optimizer, lr_scheduler)" ] }, { "cell_type": "code", "execution_count": null, "id": "f2f477b2", "metadata": { "ExecuteTime": { "start_time": "2021-12-09T15:34:14.944Z" } }, "outputs": [], "source": [ "trainer = Seq2SeqTrainer(\n", " model=model,\n", " optimizers=optimizers,\n", " tokenizer=feature_extractor,\n", " args=seq2seq_training_args,\n", " train_dataset=train_dataset,\n", " eval_dataset=eval_dataset,\n", " data_collator=default_data_collator,\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "f08d2b7c", "metadata": { "ExecuteTime": { "end_time": "2021-12-09T16:24:49.096274Z", "start_time": "2021-12-09T16:24:49.096246Z" } }, "outputs": [], "source": [ "test_img = \"../examples/tt7991608-red-notice.jpg\"\n", "with 
Image.open(test_img) as image:\n", " preds = predict(\n", " image, max_length=data_training_args.max_target_length, num_beams=data_training_args.num_beams\n", " )\n", "\n", "# Uncomment to display the test image in a jupyter notebook\n", "# display(image)\n", "print(preds[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "ecf21225", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "huggingface", "language": "python", "name": "huggingface" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }