dsr commited on
Commit
05293b5
1 Parent(s): d44c51f

Add dataset creation and model training code

Browse files
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  .vscode
2
  .ipynb_checkpoints
3
  .idea
 
 
1
  .vscode
2
  .ipynb_checkpoints
3
  .idea
4
+ datasets
5
+ output_dir
train/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ # Train new model
2
+
3
+ - Download and extract the following datasets in a new folder called datasets:
4
+
5
+ 1. [IMDb movies extensive dataset](https://www.kaggle.com/stefanoleone992/imdb-extensive-dataset)
6
+ 2. [48K IMDB Movies With Posters](https://www.kaggle.com/rezaunderfit/48k-imdb-movies-with-posters)
7
+
8
+ - Run `create_dataset.ipynb` to create train.csv and valid.csv
9
+ - Run `train.ipynb` to train the model
train/create_dataset.ipynb ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0fbed7bc",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2021-12-09T16:46:29.851016Z",
10
+ "start_time": "2021-12-09T16:46:29.841794Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "%reload_ext autoreload\n",
16
+ "%autoreload 2"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "99d6f14d",
23
+ "metadata": {
24
+ "ExecuteTime": {
25
+ "end_time": "2021-12-09T16:46:30.336104Z",
26
+ "start_time": "2021-12-09T16:46:29.852308Z"
27
+ }
28
+ },
29
+ "outputs": [],
30
+ "source": [
31
+ "from pathlib import Path\n",
32
+ "import pandas as pd\n",
33
+ "import shutil\n",
34
+ "from sklearn.model_selection import train_test_split"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "c8fcf96c",
41
+ "metadata": {
42
+ "ExecuteTime": {
43
+ "end_time": "2021-12-09T16:46:30.349125Z",
44
+ "start_time": "2021-12-09T16:46:30.337223Z"
45
+ },
46
+ "code_folding": []
47
+ },
48
+ "outputs": [],
49
+ "source": [
50
+ "def copy_images(\n",
51
+ " src_dir: Path,\n",
52
+ " des_dir: Path,\n",
53
+ " ids_with_plots: list,\n",
54
+ " delete_existing_files: bool = False,\n",
55
+ "):\n",
56
+ " \"\"\"This function copies a poster to images folder if it's id is present in the ids_with_plots list\"\"\"\n",
57
+ "\n",
58
+ " images_list = []\n",
59
+ " if delete_existing_files:\n",
60
+ " shutil.rmtree(des_dir)\n",
61
+ "\n",
62
+ " des_dir.mkdir(parents=True, exist_ok=True)\n",
63
+ "\n",
64
+ " for f in src_dir.rglob(\"*\"):\n",
65
+ " try:\n",
66
+ " if f.is_file() and f.suffix in [\".jpg\", \".jpeg\", \".png\"]:\n",
67
+ " img_name = f.name\n",
68
+ " id = Path(img_name).stem\n",
69
+ " if id in ids_with_plots:\n",
70
+ " desc_file = des_dir / img_name\n",
71
+ " shutil.copy(f, desc_file)\n",
72
+ " images_list.append((id, img_name))\n",
73
+ " except Exception as e:\n",
74
+ " print(f, e)\n",
75
+ " return images_list"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "a34124b2",
82
+ "metadata": {
83
+ "ExecuteTime": {
84
+ "end_time": "2021-12-09T16:46:30.359361Z",
85
+ "start_time": "2021-12-09T16:46:30.350299Z"
86
+ }
87
+ },
88
+ "outputs": [],
89
+ "source": [
90
+ "data_dir = Path(\"datasets\").resolve()\n",
91
+ "images_dir = data_dir / \"images\""
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "id": "8714ea01",
98
+ "metadata": {
99
+ "ExecuteTime": {
100
+ "end_time": "2021-12-09T16:46:30.781046Z",
101
+ "start_time": "2021-12-09T16:46:30.360608Z"
102
+ }
103
+ },
104
+ "outputs": [],
105
+ "source": [
106
+ "movies_df = pd.read_csv(\n",
107
+ " data_dir / \"IMDb movies.csv\", usecols=[\"imdb_title_id\", \"description\"]\n",
108
+ ")\n",
109
+ "movies_df = movies_df.rename(columns={\"imdb_title_id\": \"id\", \"description\": \"text\"})\n",
110
+ "movies_df.dropna(subset=[\"text\"], inplace=True) # Drop rows where text is empty\n",
111
+ "movies_df.head()\n"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "27f7fd94",
118
+ "metadata": {
119
+ "ExecuteTime": {
120
+ "end_time": "2021-12-09T16:46:30.792761Z",
121
+ "start_time": "2021-12-09T16:46:30.781964Z"
122
+ }
123
+ },
124
+ "outputs": [],
125
+ "source": [
126
+ "ids_with_plots = movies_df.id.tolist()"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "ebaa042a",
133
+ "metadata": {
134
+ "ExecuteTime": {
135
+ "end_time": "2021-12-09T16:47:04.704390Z",
136
+ "start_time": "2021-12-09T16:46:30.794094Z"
137
+ }
138
+ },
139
+ "outputs": [],
140
+ "source": [
141
+ "images_list = copy_images(data_dir / \"Poster\", images_dir, ids_with_plots)\n",
142
+ "images_list[0]"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": null,
148
+ "id": "17e0a874",
149
+ "metadata": {
150
+ "ExecuteTime": {
151
+ "end_time": "2021-12-09T16:47:04.724427Z",
152
+ "start_time": "2021-12-09T16:47:04.705540Z"
153
+ }
154
+ },
155
+ "outputs": [],
156
+ "source": [
157
+ "images_df = pd.DataFrame(images_list, columns=[\"id\", \"filename\"])\n",
158
+ "images_df.head()"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "bb1114e6",
165
+ "metadata": {
166
+ "ExecuteTime": {
167
+ "end_time": "2021-12-09T16:47:04.772775Z",
168
+ "start_time": "2021-12-09T16:47:04.725707Z"
169
+ }
170
+ },
171
+ "outputs": [],
172
+ "source": [
173
+ "data_df = pd.merge(movies_df, images_df, on=[\"id\"])\n",
174
+ "print(len(data_df))\n",
175
+ "data_df"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "6790815b",
182
+ "metadata": {
183
+ "ExecuteTime": {
184
+ "end_time": "2021-12-09T16:47:04.796785Z",
185
+ "start_time": "2021-12-09T16:47:04.774932Z"
186
+ }
187
+ },
188
+ "outputs": [],
189
+ "source": [
190
+ "print(len(data_df))\n",
191
+ "data_df.dropna(subset=[\"filename\"], inplace=True)\n",
192
+ "print(len(data_df))"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "id": "40c7205d",
199
+ "metadata": {
200
+ "ExecuteTime": {
201
+ "end_time": "2021-12-09T16:47:04.818522Z",
202
+ "start_time": "2021-12-09T16:47:04.798063Z"
203
+ }
204
+ },
205
+ "outputs": [],
206
+ "source": [
207
+ "print(len(data_df))\n",
208
+ "data_df.dropna(subset=[\"text\"], inplace=True)\n",
209
+ "print(len(data_df))"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "9a2d142f",
216
+ "metadata": {
217
+ "ExecuteTime": {
218
+ "end_time": "2021-12-09T16:47:04.838450Z",
219
+ "start_time": "2021-12-09T16:47:04.819726Z"
220
+ }
221
+ },
222
+ "outputs": [],
223
+ "source": [
224
+ "print(len(data_df))\n",
225
+ "data_df.drop_duplicates(subset=[\"id\"], inplace=True)\n",
226
+ "print(len(data_df))"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "id": "45f4b970",
233
+ "metadata": {
234
+ "ExecuteTime": {
235
+ "end_time": "2021-12-09T16:47:04.971652Z",
236
+ "start_time": "2021-12-09T16:47:04.839618Z"
237
+ }
238
+ },
239
+ "outputs": [],
240
+ "source": [
241
+ "data_df.to_csv(data_dir / \"data.csv\", index=False)"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "f8019a02",
248
+ "metadata": {
249
+ "ExecuteTime": {
250
+ "end_time": "2021-12-09T16:47:05.104710Z",
251
+ "start_time": "2021-12-09T16:47:04.972681Z"
252
+ }
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "train_df, valid_df = train_test_split(data_df, test_size=0.1, shuffle=True)\n",
257
+ "train_df.to_csv(data_dir / \"train.csv\", index=False)\n",
258
+ "valid_df.to_csv(data_dir / \"valid.csv\", index=False)\n",
259
+ "print(len(train_df), len(valid_df))"
260
+ ]
261
+ }
262
+ ],
263
+ "metadata": {
264
+ "kernelspec": {
265
+ "display_name": "huggingface",
266
+ "language": "python",
267
+ "name": "huggingface"
268
+ },
269
+ "language_info": {
270
+ "codemirror_mode": {
271
+ "name": "ipython",
272
+ "version": 3
273
+ },
274
+ "file_extension": ".py",
275
+ "mimetype": "text/x-python",
276
+ "name": "python",
277
+ "nbconvert_exporter": "python",
278
+ "pygments_lexer": "ipython3",
279
+ "version": "3.9.7"
280
+ }
281
+ },
282
+ "nbformat": 4,
283
+ "nbformat_minor": 5
284
+ }
train/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ pandas==1.3.4
3
+ scikit-learn==1.0.1
4
+ python-box==5.4.1
5
+ transformers==4.12.5
6
+ torch==1.10.0+cu113
7
+ Pillow==8.4.0
train/train.ipynb ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "0fbed7bc",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2021-12-09T15:34:14.921553Z",
10
+ "start_time": "2021-12-09T15:34:14.911112Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "%reload_ext autoreload\n",
16
+ "%autoreload 2"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "id": "c4b60ef3",
23
+ "metadata": {
24
+ "ExecuteTime": {
25
+ "end_time": "2021-12-09T15:34:15.961098Z",
26
+ "start_time": "2021-12-09T15:34:14.922771Z"
27
+ },
28
+ "code_folding": []
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "# imports\n",
33
+ "\n",
34
+ "import pandas as pd\n",
35
+ "import os\n",
36
+ "from pathlib import Path\n",
37
+ "from PIL import Image\n",
38
+ "import shutil\n",
39
+ "from logging import root\n",
40
+ "from PIL import Image\n",
41
+ "from pathlib import Path\n",
42
+ "import pandas as pd\n",
43
+ "import torch\n",
44
+ "from torch.utils.data import Dataset\n",
45
+ "from PIL import Image\n",
46
+ "from transformers import (\n",
47
+ " Seq2SeqTrainer,\n",
48
+ " Seq2SeqTrainingArguments,\n",
49
+ " get_linear_schedule_with_warmup,\n",
50
+ " AutoFeatureExtractor,\n",
51
+ " AutoTokenizer,\n",
52
+ " ViTFeatureExtractor,\n",
53
+ " VisionEncoderDecoderModel,\n",
54
+ " default_data_collator,\n",
55
+ ")\n",
56
+ "from transformers.optimization import AdamW\n",
57
+ "\n",
58
+ "from box import Box\n",
59
+ "import inspect\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "99d6f14d",
66
+ "metadata": {
67
+ "ExecuteTime": {
68
+ "end_time": "2021-12-09T15:34:15.979191Z",
69
+ "start_time": "2021-12-09T15:34:15.962078Z"
70
+ },
71
+ "code_folding": []
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "# custom functions\n",
76
+ "\n",
77
+ "class ImageCaptionDataset(Dataset):\n",
78
+ " def __init__(\n",
79
+ " self, df, feature_extractor, tokenizer, images_dir, max_target_length=128\n",
80
+ " ):\n",
81
+ " self.df = df\n",
82
+ " self.feature_extractor = feature_extractor\n",
83
+ " self.tokenizer = tokenizer\n",
84
+ " self.images_dir = images_dir\n",
85
+ " self.max_target_length = max_target_length\n",
86
+ "\n",
87
+ " def __len__(self):\n",
88
+ " return len(self.df)\n",
89
+ "\n",
90
+ " def __getitem__(self, idx):\n",
91
+ " filename = self.df[\"filename\"][idx]\n",
92
+ " text = self.df[\"text\"][idx]\n",
93
+ " # prepare image (i.e. resize + normalize)\n",
94
+ " image = Image.open(self.images_dir / filename).convert(\"RGB\")\n",
95
+ " pixel_values = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n",
96
+ " # add labels (input_ids) by encoding the text\n",
97
+ " labels = self.tokenizer(\n",
98
+ " text,\n",
99
+ " padding=\"max_length\",\n",
100
+ " truncation=True,\n",
101
+ " max_length=self.max_target_length,\n",
102
+ " ).input_ids\n",
103
+ " # important: make sure that PAD tokens are ignored by the loss function\n",
104
+ " labels = [\n",
105
+ " label if label != self.tokenizer.pad_token_id else -100 for label in labels\n",
106
+ " ]\n",
107
+ "\n",
108
+ " encoding = {\n",
109
+ " \"pixel_values\": pixel_values.squeeze(),\n",
110
+ " \"labels\": torch.tensor(labels),\n",
111
+ " }\n",
112
+ " return encoding\n",
113
+ "\n",
114
+ "\n",
115
+ "\n",
116
+ "def predict(image, max_length=64, num_beams=4):\n",
117
+ "\n",
118
+ " pixel_values = feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n",
119
+ " pixel_values = pixel_values.to(device)\n",
120
+ "\n",
121
+ " with torch.no_grad():\n",
122
+ " output_ids = model.generate(\n",
123
+ " pixel_values,\n",
124
+ " max_length=max_length,\n",
125
+ " num_beams=num_beams,\n",
126
+ " return_dict_in_generate=True,\n",
127
+ " ).sequences\n",
128
+ "\n",
129
+ " preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n",
130
+ " preds = [pred.strip() for pred in preds]\n",
131
+ "\n",
132
+ " return preds\n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "ea66826b",
139
+ "metadata": {
140
+ "ExecuteTime": {
141
+ "end_time": "2021-12-09T15:34:16.042990Z",
142
+ "start_time": "2021-12-09T15:34:15.980557Z"
143
+ }
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "data_dir = Path(\"datasets\").resolve()\n",
148
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
149
+ "print(device)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "id": "17cfb2c2",
156
+ "metadata": {
157
+ "ExecuteTime": {
158
+ "end_time": "2021-12-09T15:34:16.058421Z",
159
+ "start_time": "2021-12-09T15:34:16.044111Z"
160
+ }
161
+ },
162
+ "outputs": [],
163
+ "source": [
164
+ "# arguments pertaining to what data we are going to input our model for training and eval.\n",
165
+ "\n",
166
+ "data_training_args = {\n",
167
+ " # The maximum total sequence length for target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.\n",
168
+ " \"max_target_length\": 64,\n",
169
+ "\n",
170
+ " # Number of beams to use for evaluation. This argument will be passed to model.generate which is used during evaluate and predict.\n",
171
+ " \"num_beams\": 4,\n",
172
+ "\n",
173
+ " # Folder with all the images\n",
174
+ " \"images_dir\": data_dir / \"images\",\n",
175
+ "}\n",
176
+ "\n",
177
+ "data_training_args = Box(data_training_args)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "adc4839a",
184
+ "metadata": {
185
+ "ExecuteTime": {
186
+ "end_time": "2021-12-09T15:34:16.073242Z",
187
+ "start_time": "2021-12-09T15:34:16.059354Z"
188
+ }
189
+ },
190
+ "outputs": [],
191
+ "source": [
192
+ "# arguments pertaining to which model/config/tokenizer we are going to fine-tune from.\n",
193
+ "\n",
194
+ "model_args = {\n",
195
+ "\n",
196
+ " # Path to pretrained model or model identifier from huggingface.co/models\"\n",
197
+ " \"encoder_model_name_or_path\": \"google/vit-base-patch16-224-in21k\",\n",
198
+ "\n",
199
+ " # Path to pretrained model or model identifier from huggingface.co/models\"\n",
200
+ " \"decoder_model_name_or_path\": \"gpt2\",\n",
201
+ "\n",
202
+ " # If set to int > 0, all ngrams of that size can only occur once.\n",
203
+ " \"no_repeat_ngram_size\": 3,\n",
204
+ "\n",
205
+ " # Exponential penalty to the length that will be used by default in the generate method of the model.\n",
206
+ " \"length_penalty\": 2.0,\n",
207
+ "}\n",
208
+ "\n",
209
+ "model_args = Box(model_args)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "22b8c9e3",
216
+ "metadata": {
217
+ "ExecuteTime": {
218
+ "end_time": "2021-12-09T15:34:16.089201Z",
219
+ "start_time": "2021-12-09T15:34:16.074223Z"
220
+ }
221
+ },
222
+ "outputs": [],
223
+ "source": [
224
+ "# arguments pertaining to Trainer class. Refer: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments\n",
225
+ "\n",
226
+ "training_args = {\n",
227
+ " \"num_train_epochs\": 5,\n",
228
+ " \"per_device_train_batch_size\": 32,\n",
229
+ " \"per_device_eval_batch_size\": 32,\n",
230
+ " \"output_dir\": \"output_dir\",\n",
231
+ " \"do_train\": True,\n",
232
+ " \"do_eval\": True,\n",
233
+ " \"fp16\": True,\n",
234
+ " \"learning_rate\": 1e-5,\n",
235
+ " \"load_best_model_at_end\": True,\n",
236
+ " \"evaluation_strategy\": \"epoch\",\n",
237
+ " \"save_strategy\": \"epoch\",\n",
238
+ " \"report_to\": \"none\"\n",
239
+ "}\n",
240
+ "\n",
241
+ "seq2seq_training_args = Seq2SeqTrainingArguments(**training_args)"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "d0023eac",
248
+ "metadata": {
249
+ "ExecuteTime": {
250
+ "end_time": "2021-12-09T15:34:37.844396Z",
251
+ "start_time": "2021-12-09T15:34:16.090085Z"
252
+ }
253
+ },
254
+ "outputs": [],
255
+ "source": [
256
+ "feature_extractor = ViTFeatureExtractor.from_pretrained(\n",
257
+ " model_args.encoder_model_name_or_path\n",
258
+ ")\n",
259
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
260
+ " model_args.decoder_model_name_or_path, use_fast=True\n",
261
+ ")\n",
262
+ "tokenizer.pad_token = tokenizer.eos_token\n",
263
+ "\n",
264
+ "model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(\n",
265
+ " model_args.encoder_model_name_or_path, model_args.decoder_model_name_or_path\n",
266
+ ")\n",
267
+ "\n",
268
+ "# set special tokens used for creating the decoder_input_ids from the labels\n",
269
+ "model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
270
+ "model.config.pad_token_id = tokenizer.pad_token_id\n",
271
+ "# make sure vocab size is set correctly\n",
272
+ "model.config.vocab_size = model.config.decoder.vocab_size\n",
273
+ "\n",
274
+ "# set beam search parameters\n",
275
+ "model.config.eos_token_id = tokenizer.sep_token_id\n",
276
+ "model.config.max_length = data_training_args.max_target_length\n",
277
+ "model.config.no_repeat_ngram_size = model_args.no_repeat_ngram_size\n",
278
+ "model.config.length_penalty = model_args.length_penalty\n",
279
+ "model.config.num_beams = data_training_args.num_beams\n",
280
+ "model.decoder.resize_token_embeddings(len(tokenizer))\n"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "code",
285
+ "execution_count": null,
286
+ "id": "6428ea08",
287
+ "metadata": {
288
+ "ExecuteTime": {
289
+ "end_time": "2021-12-09T15:34:37.933804Z",
290
+ "start_time": "2021-12-09T15:34:37.845607Z"
291
+ }
292
+ },
293
+ "outputs": [],
294
+ "source": [
295
+ "train_df = pd.read_csv(data_dir / \"train.csv\")\n",
296
+ "valid_df = pd.read_csv(data_dir / \"valid.csv\")\n",
297
+ "\n",
298
+ "train_dataset = ImageCaptionDataset(\n",
299
+ " df=train_df,\n",
300
+ " feature_extractor=feature_extractor,\n",
301
+ " tokenizer=tokenizer,\n",
302
+ " images_dir=data_training_args.images_dir,\n",
303
+ " max_target_length=data_training_args.max_target_length,\n",
304
+ ")\n",
305
+ "eval_dataset = ImageCaptionDataset(\n",
306
+ " df=valid_df,\n",
307
+ " feature_extractor=feature_extractor,\n",
308
+ " tokenizer=tokenizer,\n",
309
+ " images_dir=data_training_args.images_dir,\n",
310
+ " max_target_length=data_training_args.max_target_length,\n",
311
+ ")\n",
312
+ "\n",
313
+ "print(f\"Number of training examples: {len(train_dataset)}\")\n",
314
+ "print(f\"Number of validation examples: {len(eval_dataset)}\")"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "id": "c8e492a1",
321
+ "metadata": {
322
+ "ExecuteTime": {
323
+ "end_time": "2021-12-09T15:34:37.971630Z",
324
+ "start_time": "2021-12-09T15:34:37.935339Z"
325
+ }
326
+ },
327
+ "outputs": [],
328
+ "source": [
329
+ "# Let's verify an example from the training dataset:\n",
330
+ "\n",
331
+ "encoding = train_dataset[0]\n",
332
+ "for k,v in encoding.items():\n",
333
+ " print(k, v.shape)"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "edb4e7a6",
340
+ "metadata": {
341
+ "ExecuteTime": {
342
+ "end_time": "2021-12-09T15:34:38.006980Z",
343
+ "start_time": "2021-12-09T15:34:37.972483Z"
344
+ }
345
+ },
346
+ "outputs": [],
347
+ "source": [
348
+ "# We can also check the original image and decode the labels:\n",
349
+ "image = Image.open(data_training_args.images_dir / train_df[\"filename\"][0]).convert(\"RGB\")\n",
350
+ "image"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "25f2cae7",
357
+ "metadata": {
358
+ "ExecuteTime": {
359
+ "end_time": "2021-12-09T15:34:38.031745Z",
360
+ "start_time": "2021-12-09T15:34:38.008027Z"
361
+ }
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "labels = encoding[\"labels\"]\n",
366
+ "labels[labels == -100] = tokenizer.pad_token_id\n",
367
+ "label_str = tokenizer.decode(labels, skip_special_tokens=True)\n",
368
+ "print(label_str)\n"
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "execution_count": null,
374
+ "id": "b7a009d3",
375
+ "metadata": {
376
+ "ExecuteTime": {
377
+ "end_time": "2021-12-09T15:34:38.049539Z",
378
+ "start_time": "2021-12-09T15:34:38.032749Z"
379
+ }
380
+ },
381
+ "outputs": [],
382
+ "source": [
383
+ "optimizer = AdamW(model.parameters(), lr=seq2seq_training_args.learning_rate)\n",
384
+ "\n",
385
+ "steps_per_epoch = len(train_dataset) // seq2seq_training_args.per_device_train_batch_size\n",
386
+ "num_training_steps = steps_per_epoch * seq2seq_training_args.num_train_epochs\n",
387
+ "\n",
388
+ "lr_scheduler = get_linear_schedule_with_warmup(\n",
389
+ " optimizer,\n",
390
+ " num_warmup_steps=seq2seq_training_args.warmup_steps,\n",
391
+ " num_training_steps=num_training_steps,\n",
392
+ ")\n",
393
+ "\n",
394
+ "optimizers = (optimizer, lr_scheduler)"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "f2f477b2",
401
+ "metadata": {
402
+ "ExecuteTime": {
403
+ "start_time": "2021-12-09T15:34:14.944Z"
404
+ }
405
+ },
406
+ "outputs": [],
407
+ "source": [
408
+ "trainer = Seq2SeqTrainer(\n",
409
+ " model=model,\n",
410
+ " optimizers=optimizers,\n",
411
+ " tokenizer=feature_extractor,\n",
412
+ " args=seq2seq_training_args,\n",
413
+ " train_dataset=train_dataset,\n",
414
+ " eval_dataset=eval_dataset,\n",
415
+ " data_collator=default_data_collator,\n",
416
+ ")\n",
417
+ "\n",
418
+ "trainer.train()"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "f08d2b7c",
425
+ "metadata": {
426
+ "ExecuteTime": {
427
+ "end_time": "2021-12-09T16:24:49.096274Z",
428
+ "start_time": "2021-12-09T16:24:49.096246Z"
429
+ }
430
+ },
431
+ "outputs": [],
432
+ "source": [
433
+ "test_img = \"../examples/tt7991608-red-notice.jpg\"\n",
434
+ "with Image.open(test_img) as image:\n",
435
+ " preds = predict(\n",
436
+ " image, max_length=data_training_args.max_target_length, num_beams=data_training_args.num_beams\n",
437
+ " )\n",
438
+ "\n",
439
+ "# Uncomment to display the test image in a jupyter notebook\n",
440
+ "# display(image)\n",
441
+ "print(preds[0])"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "execution_count": null,
447
+ "id": "ecf21225",
448
+ "metadata": {},
449
+ "outputs": [],
450
+ "source": []
451
+ }
452
+ ],
453
+ "metadata": {
454
+ "kernelspec": {
455
+ "display_name": "huggingface",
456
+ "language": "python",
457
+ "name": "huggingface"
458
+ },
459
+ "language_info": {
460
+ "codemirror_mode": {
461
+ "name": "ipython",
462
+ "version": 3
463
+ },
464
+ "file_extension": ".py",
465
+ "mimetype": "text/x-python",
466
+ "name": "python",
467
+ "nbconvert_exporter": "python",
468
+ "pygments_lexer": "ipython3",
469
+ "version": "3.9.7"
470
+ }
471
+ },
472
+ "nbformat": 4,
473
+ "nbformat_minor": 5
474
+ }