jerome committed on
Commit a0446c6
1 Parent(s): c671ac3

initial commit

Files changed (2)
  1. README.md +1 -0
  2. flava_finetuning_tutorial.ipynb +325 -0
README.md ADDED
@@ -0,0 +1 @@
+ FLAVA fine-tuning
flava_finetuning_tutorial.ipynb ADDED
@@ -0,0 +1,325 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oUL6DV1zCIlB"
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WmJySTGXCIlD"
+ },
+ "source": [
+ "\n",
+ "# TorchMultimodal Tutorial: Finetuning FLAVA\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZJCb2uRyCIlE"
+ },
+ "source": [
+ "Multimodal AI has recently become very popular owing to its ubiquitous\n",
+ "nature, from use cases like image captioning and visual search to more\n",
+ "recent applications like image generation from text. **TorchMultimodal\n",
+ "is a library powered by PyTorch consisting of building blocks and end-to-end\n",
+ "examples, aiming to enable and accelerate research in\n",
+ "multimodality**.\n",
+ "\n",
+ "In this tutorial, we will demonstrate how to use a **pretrained state-of-the-art\n",
+ "model called** [FLAVA](https://arxiv.org/pdf/2112.04482.pdf) **from the\n",
+ "TorchMultimodal library to fine-tune on a multimodal task, i.e. visual\n",
+ "question answering** (VQA). The model consists of two unimodal transformer-based\n",
+ "encoders for text and image and a multimodal encoder that combines\n",
+ "the two embeddings. It is pretrained using contrastive, image-text matching, and\n",
+ "text, image, and multimodal masking losses.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0TjU3iQgCIlE"
+ },
+ "source": [
+ "## Installation\n",
+ "We will use the TextVQA dataset and the BERT tokenizer from HuggingFace for this\n",
+ "tutorial, so you need to install datasets and transformers in addition to TorchMultimodal.\n",
+ "\n",
+ "<div class=\"alert alert-info\"><h4>Note</h4><p>When running this tutorial in Google Colab, install the required packages by\n",
+ "creating a new cell and running the following commands:\n",
+ "\n",
+ "```\n",
+ "!pip install torchmultimodal-nightly\n",
+ "!pip install datasets\n",
+ "!pip install transformers\n",
+ "```\n",
+ "</p></div>\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LGuYfyaJCIlE"
+ },
+ "source": [
+ "## Steps\n",
+ "\n",
+ "1. Download the vocabulary file containing the answer classes to a directory on your computer by running the following commands:\n",
+ "\n",
+ "```\n",
+ "wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz\n",
+ "tar xf vocab.tar.gz\n",
+ "```\n",
+ "\n",
+ "<div class=\"alert alert-info\"><h4>Note</h4><p>If you are running this tutorial in Google Colab, run these commands\n",
+ "in a new cell and prepend each command with an exclamation mark (!).</p></div>\n",
+ "\n",
+ "2. For this tutorial, we treat VQA as a classification task where\n",
+ "   the inputs are an image and a question (text) and the output is an answer class.\n",
+ "   So we need to download the vocabulary file with the answer classes and create the\n",
+ "   answer-to-label mapping.\n",
+ "\n",
+ "   We also load the [textvqa dataset](https://arxiv.org/pdf/1904.08920.pdf) containing 34602 training samples\n",
+ "   (images, questions, and answers) from HuggingFace.\n",
+ "\n",
+ "We see there are 3997 answer classes, including a class representing\n",
+ "unknown answers.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "b6c1oq0lCIlF"
+ },
+ "outputs": [],
+ "source": [
+ "with open(\"data/vocabs/answers_textvqa_more_than_1.txt\") as f:\n",
+ "    vocab = f.readlines()\n",
+ "\n",
+ "answer_to_idx = {}\n",
+ "for idx, entry in enumerate(vocab):\n",
+ "    answer_to_idx[entry.strip(\"\\n\")] = idx\n",
+ "print(len(vocab))\n",
+ "print(vocab[:5])\n",
+ "\n",
+ "from datasets import load_dataset\n",
+ "dataset = load_dataset(\"textvqa\")"
+ ]
+ },
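+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional helper, we can also build the inverse mapping from label index back to answer string\n",
+ "(here called `idx_to_answer`, an illustrative name). It is derived purely from the `answer_to_idx`\n",
+ "dictionary above and lets us decode predicted class indices into readable answers later:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional helper: inverse mapping from class index back to answer string,\n",
+ "# used later to decode model predictions into readable answers.\n",
+ "idx_to_answer = {idx: ans for ans, idx in answer_to_idx.items()}\n",
+ "print(idx_to_answer[0], idx_to_answer[1])"
+ ]
+ },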
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kGCla9GgCIlF"
+ },
+ "source": [
+ "Let's display a sample entry from the dataset:\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GLS8HGYtCIlF"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "idx = 5\n",
+ "print(\"Question: \", dataset[\"train\"][idx][\"question\"])\n",
+ "print(\"Answers: \", dataset[\"train\"][idx][\"answers\"])\n",
+ "im = np.asarray(dataset[\"train\"][idx][\"image\"].resize((500,500)))\n",
+ "plt.imshow(im)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "J1UO_daoCIlG"
+ },
+ "source": [
+ "3. Next, we write the transform function to convert the image and text into\n",
+ "Tensors consumable by our model:\n",
+ "- For images, we use the transforms from torchvision to convert to Tensor and resize to uniform sizes.\n",
+ "- For text, we tokenize (and pad) the questions using the BertTokenizer from HuggingFace.\n",
+ "- For answers (i.e. labels), we take the most frequently occurring answer as the label to train with:\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "rO7lCn4DCIlG"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torchvision import transforms\n",
+ "from collections import defaultdict\n",
+ "from transformers import BertTokenizer\n",
+ "from functools import partial\n",
+ "\n",
+ "def transform(tokenizer, input):\n",
+ "    batch = {}\n",
+ "    # Convert the PIL image to a tensor and resize to a uniform size\n",
+ "    image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])\n",
+ "    image = image_transform(input[\"image\"][0].convert(\"RGB\"))\n",
+ "    batch[\"image\"] = [image]\n",
+ "\n",
+ "    # Tokenize and pad the question text\n",
+ "    tokenized = tokenizer(input[\"question\"], return_tensors='pt', padding=\"max_length\", max_length=512)\n",
+ "    batch.update(tokenized)\n",
+ "\n",
+ "    # Use the most frequently occurring answer as the training label\n",
+ "    ans_to_count = defaultdict(int)\n",
+ "    for ans in input[\"answers\"][0]:\n",
+ "        ans_to_count[ans] += 1\n",
+ "    max_value = max(ans_to_count, key=ans_to_count.get)\n",
+ "    ans_idx = answer_to_idx.get(max_value, 0)\n",
+ "    batch[\"answers\"] = torch.as_tensor([ans_idx])\n",
+ "    return batch\n",
+ "\n",
+ "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\", padding=\"max_length\", max_length=512)\n",
+ "transform = partial(transform, tokenizer)\n",
+ "dataset.set_transform(transform)"
+ ]
+ },
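+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since each question is padded to a fixed length of 512 tokens, a quick sanity check of the tokenizer\n",
+ "output can be helpful. The sketch below (using an illustrative question string, not one taken from the\n",
+ "dataset) calls the tokenizer the same way the transform does and prints the shape of the resulting\n",
+ "`input_ids` tensor:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative question string; the tokenizer call mirrors the one inside transform()\n",
+ "sample = tokenizer(\"what is written on the sign?\", return_tensors='pt', padding=\"max_length\", max_length=512)\n",
+ "# Expect a (1, 512) tensor of token ids after padding to max_length\n",
+ "print(sample[\"input_ids\"].shape)"
+ ]
+ },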
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LOMy3UbpCIlG"
+ },
+ "source": [
+ "4. Finally, we import `flava_model_for_classification` from\n",
+ "torchmultimodal. It loads the pretrained FLAVA checkpoint by default and\n",
+ "includes a classification head.\n",
+ "\n",
+ "The model forward function passes the image through the visual encoder\n",
+ "and the question through the text encoder. The image and question\n",
+ "embeddings are then passed through the multimodal encoder. The final\n",
+ "embedding corresponding to the CLS token is passed through an MLP head,\n",
+ "which finally gives a probability distribution over the possible\n",
+ "answer classes.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "drSfcYNCCIlG"
+ },
+ "outputs": [],
+ "source": [
+ "from torchmultimodal.models.flava.model import flava_model_for_classification\n",
+ "model = flava_model_for_classification(num_classes=len(vocab))"
+ ]
+ },
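+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To connect the code with the description above, we can list the model's top-level submodules.\n",
+ "We expect to see the image encoder, the text encoder, the multimodal encoder, and the classification\n",
+ "head, though the exact attribute names may differ across torchmultimodal versions, so treat this as\n",
+ "an inspection aid rather than part of the training recipe:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# List the top-level submodules of the classification model.\n",
+ "# Exact attribute names may vary between torchmultimodal versions;\n",
+ "# this only prints whatever the loaded model actually contains.\n",
+ "for name, module in model.named_children():\n",
+ "    print(name, \"->\", type(module).__name__)"
+ ]
+ },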
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "976mlWvaCIlG"
+ },
+ "source": [
+ "5. We put together the dataset and model in a toy training loop to\n",
+ "demonstrate how to train the model for 3 iterations:\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0KvxQ4xaCIlH"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "BATCH_SIZE = 2\n",
+ "MAX_STEPS = 3\n",
+ "\n",
+ "train_dataloader = DataLoader(dataset[\"train\"], batch_size=BATCH_SIZE)\n",
+ "optimizer = torch.optim.AdamW(model.parameters())\n",
+ "\n",
+ "epochs = 1\n",
+ "for _ in range(epochs):\n",
+ "    for idx, batch in enumerate(train_dataloader):\n",
+ "        optimizer.zero_grad()\n",
+ "        out = model(text=batch[\"input_ids\"], image=batch[\"image\"], labels=batch[\"answers\"])\n",
+ "        loss = out.loss\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "        print(f\"Loss at step {idx} = {loss}\")\n",
+ "        # Stop after MAX_STEPS iterations; this loop is only a demonstration\n",
+ "        if idx >= MAX_STEPS - 1:\n",
+ "            break"
+ ]
+ },
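+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "After the toy loop, here is a minimal sketch of how a prediction could be decoded back into an answer\n",
+ "string. It assumes the classification output exposes a `logits` field alongside the `loss` used above\n",
+ "(verify this against your torchmultimodal version), and it reuses a batch from the training dataloader\n",
+ "together with the `idx_to_answer` mapping built after loading the vocabulary:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch of decoding a prediction into an answer string.\n",
+ "# Assumption: the classification output exposes `logits` in addition to `loss`;\n",
+ "# verify the field name against your torchmultimodal version.\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    batch = next(iter(train_dataloader))\n",
+ "    out = model(text=batch[\"input_ids\"], image=batch[\"image\"], labels=batch[\"answers\"])\n",
+ "    pred = out.logits.argmax(dim=-1)  # highest-scoring answer class per sample (assumed field)\n",
+ "    # idx_to_answer was built right after loading the answer vocabulary\n",
+ "    print([idx_to_answer.get(int(p), \"<unknown>\") for p in pred])"
+ ]
+ },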
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A7An1sjZCIlH"
+ },
+ "source": [
+ "## Conclusion\n",
+ "\n",
+ "This tutorial introduced the basics of how to fine-tune on a\n",
+ "multimodal task using FLAVA from TorchMultimodal. Please also check out\n",
+ "other examples from the library, like\n",
+ "[MDETR](https://github.com/facebookresearch/multimodal/tree/main/torchmultimodal/models/mdetr),\n",
+ "which is a multimodal model for object detection, and\n",
+ "[Omnivore](https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py),\n",
+ "which is a multitask model spanning image, video, and 3D classification.\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "colab": {
+ "provenance": []
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }