Pedro Cuenca committed on
Commit
95d2faf
1 Parent(s): 16f038a

* Data preprocessing pipeline proof of concept.

Browse files
Files changed (1) hide show
  1. model/data-pipeline.ipynb +366 -0
model/data-pipeline.ipynb ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "bf8fb38a",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Data Pipeline"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "9b83dcb9",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from dataclasses import dataclass, field\n",
19
+ "from pathlib import Path\n",
20
+ "\n",
21
+ "import datasets\n",
22
+ "from datasets import Dataset, load_dataset\n",
23
+ "import numpy as np\n",
24
+ "\n",
25
+ "from transformers import BartTokenizer\n",
26
+ "\n",
27
+ "from tqdm import tqdm\n",
28
+ "\n",
29
+ "import jax\n",
30
+ "import jax.numpy as jnp\n",
31
+ "\n",
32
+ "from flax.training.common_utils import shard"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "a661a89e",
38
+ "metadata": {},
39
+ "source": [
40
+ "File containing image paths, captions and VQGAN-encoded indices."
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 2,
46
+ "id": "0e84e889",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "datafile = '/data/CC12M/images-encoded-10000.tsv' # 9999 encoded images from CC12M"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "7fdc640b",
56
+ "metadata": {},
57
+ "source": [
58
+ "TODO: generate train/test splits if necessary."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 3,
64
+ "id": "cc6789b4",
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "name": "stderr",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "Using custom data configuration default-91833df78e844785\n",
72
+ "Reusing dataset csv (/home/pedro/.cache/huggingface/datasets/csv/default-91833df78e844785/0.0.0/e138af468cb14e747fb46a19c787ffcfa5170c821476d20d5304287ce12bbc23)\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "dataset = load_dataset('csv', delimiter='\\t', data_files=[datafile])"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 4,
83
+ "id": "f3ed4919",
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "DatasetDict({\n",
90
+ " train: Dataset({\n",
91
+ " features: ['image_file', 'caption', 'encoding'],\n",
92
+ " num_rows: 9999\n",
93
+ " })\n",
94
+ "})"
95
+ ]
96
+ },
97
+ "execution_count": 4,
98
+ "metadata": {},
99
+ "output_type": "execute_result"
100
+ }
101
+ ],
102
+ "source": [
103
+ "dataset"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 5,
109
+ "id": "a70c7354",
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "data": {
114
+ "text/plain": [
115
+ "Dataset({\n",
116
+ " features: ['image_file', 'caption', 'encoding'],\n",
117
+ " num_rows: 9999\n",
118
+ "})"
119
+ ]
120
+ },
121
+ "execution_count": 5,
122
+ "metadata": {},
123
+ "output_type": "execute_result"
124
+ }
125
+ ],
126
+ "source": [
127
+ "dataset = dataset[\"train\"]\n",
128
+ "dataset"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "id": "a73454cf",
134
+ "metadata": {},
135
+ "source": [
136
+ "We don't really need the `image_file` field for training. We'll drop it during pre-processing because we won't be able to numericalize it to a `jnp.array`, which would be required in JAX."
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "7c0fa992",
142
+ "metadata": {},
143
+ "source": [
144
+ "## Preprocessing"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "markdown",
149
+ "id": "a0e36582",
150
+ "metadata": {},
151
+ "source": [
152
+ "The `encoding` field contains a string representation of the encoded indices. We'll convert them to numbers. We also need to tokenize the captions."
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 6,
158
+ "id": "d46f6ac5",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
+ "max_length = 256 # Read from data_args.max_source_length\n",
164
+ "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 7,
170
+ "id": "4cac6643",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
def preprocess_function(examples):
    """
    Tokenize a batch of captions and parse the stringified image encodings.

    Args:
        examples: Batch mapping that provides a "caption" entry (list of caption
            strings) and an "encoding" entry (list of strings, each the textual
            form of a Python literal holding the encoded indices).

    Returns:
        The tokenizer output for the captions, with an extra "eval_encoding"
        entry holding the parsed encodings, one per example.

    Raises:
        ValueError or SyntaxError: if an "encoding" string is not a plain Python literal.

    Relies on the module-level `tokenizer` and `max_length`.
    """
    # Imported locally so the function stays self-contained when it is shipped
    # to worker processes.
    import ast

    captions = examples["caption"]
    # TODO: decide whether captions need a task prefix (prefix + caption).

    # Fixed-length padding: jitted functions need inputs of a constant shape.
    model_inputs = tokenizer(
        captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="np"
    )

    # literal_eval only accepts literals (lists, numbers, ...), so a malformed or
    # malicious field in the data file cannot execute code the way eval() would.
    model_inputs["eval_encoding"] = [ast.literal_eval(indices) for indices in examples["encoding"]]

    return model_inputs
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 8,
189
+ "id": "e6a4cb91",
190
+ "metadata": {},
191
+ "outputs": [],
192
+ "source": [
193
# Run the preprocessing in parallel worker processes.
num_workers = 48  # We have 96 processors in the TPU
column_names = dataset.column_names
dataset = dataset.map(
    preprocess_function,
    remove_columns=column_names,  # drop the original columns (captured above, before mapping)
    batched=True,
    num_proc=num_workers,  # single source of truth instead of a repeated literal
)
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": 9,
205
+ "id": "a9b1b467",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": [
209
def data_loader(rng: jax.random.PRNGKey, dataset: "Dataset", batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
    Shuffle batches if `shuffle` is `True`.

    Args:
        rng: JAX PRNG key; only consumed when `shuffle` is `True`.
        dataset: Object supporting `len()` and indexing with an array of row
            indices, returning a mapping from column name to the selected values.
        batch_size: Number of examples in each yielded batch, before sharding.
        shuffle: Visit the examples in a random order derived from `rng`.

    Yields:
        One dict per batch, mapping each column name to a `jnp.array` that has
        been passed through `shard`.

    Examples that do not fill a complete final batch are dropped, so every
    yielded batch has the same shape.
    """
    steps_per_epoch = len(dataset) // batch_size

    # Build the indices on the host as NumPy arrays: they are only used to index
    # `dataset`, so there is no reason to keep them as JAX arrays.
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        yield batch
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 10,
234
+ "id": "0a628505",
235
+ "metadata": {},
236
+ "outputs": [
237
+ {
238
+ "name": "stderr",
239
+ "output_type": "stream",
240
+ "text": [
241
+ "INFO:absl:Starting the local TPU driver.\n",
242
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
243
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n"
244
+ ]
245
+ }
246
+ ],
247
+ "source": [
248
+ "rng = jax.random.PRNGKey(23) # Use training_args.seed\n",
249
+ "batch_size = 64 # Per device\n",
250
+ "super_batch_size = batch_size * jax.device_count()"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 11,
256
+ "id": "b3a5ce7d",
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "loader = data_loader(rng, dataset, batch_size=super_batch_size)"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": 12,
266
+ "id": "67aa8f9c",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "superbatch = next(iter(loader))"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": 13,
276
+ "id": "7cd99402",
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "data": {
281
+ "text/plain": [
282
+ "dict_keys(['attention_mask', 'eval_encoding', 'input_ids'])"
283
+ ]
284
+ },
285
+ "execution_count": 13,
286
+ "metadata": {},
287
+ "output_type": "execute_result"
288
+ }
289
+ ],
290
+ "source": [
291
+ "superbatch.keys()"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 14,
297
+ "id": "652a4a9e",
298
+ "metadata": {},
299
+ "outputs": [
300
+ {
301
+ "data": {
302
+ "text/plain": [
303
+ "8"
304
+ ]
305
+ },
306
+ "execution_count": 14,
307
+ "metadata": {},
308
+ "output_type": "execute_result"
309
+ }
310
+ ],
311
+ "source": [
312
+ "len(superbatch[\"eval_encoding\"])"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 15,
318
+ "id": "de7de4e8",
319
+ "metadata": {},
320
+ "outputs": [
321
+ {
322
+ "data": {
323
+ "text/plain": [
324
+ "(8, 64, 256)"
325
+ ]
326
+ },
327
+ "execution_count": 15,
328
+ "metadata": {},
329
+ "output_type": "execute_result"
330
+ }
331
+ ],
332
+ "source": [
333
+ "superbatch[\"eval_encoding\"].shape"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "cfe23a71",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": []
343
+ }
344
+ ],
345
+ "metadata": {
346
+ "kernelspec": {
347
+ "display_name": "Python 3 (ipykernel)",
348
+ "language": "python",
349
+ "name": "python3"
350
+ },
351
+ "language_info": {
352
+ "codemirror_mode": {
353
+ "name": "ipython",
354
+ "version": 3
355
+ },
356
+ "file_extension": ".py",
357
+ "mimetype": "text/x-python",
358
+ "name": "python",
359
+ "nbconvert_exporter": "python",
360
+ "pygments_lexer": "ipython3",
361
+ "version": "3.8.10"
362
+ }
363
+ },
364
+ "nbformat": 4,
365
+ "nbformat_minor": 5
366
+ }