Pedro Cuenca committed on
Commit 6047b49
1 Parent(s): ecf5f29

Notebooks that demonstrate streaming encoding


Using either Hugging Face Datasets, or webdataset.

Note that parallel processing is not possible for Hugging Face Datasets
in streaming mode. A local copy, or the use of webdataset, is preferred
for large streaming datasets.
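
For context, a minimal sketch (in Python) of the two loading styles mentioned above; the dataset paths and shard patterns are placeholders, not the ones used in the notebooks:

    # Hugging Face Datasets in streaming mode: simple to set up, but the stream
    # is consumed by a single process, so decoding cannot be parallelized.
    from datasets import load_dataset
    streamed = load_dataset("json", data_files="captions/*.jsonl", split="train", streaming=True)

    # webdataset: shards are plain tar files, so a PyTorch DataLoader can assign
    # different shards to different workers and decode in parallel.
    import torch
    import webdataset as wds

    dataset = wds.WebDataset("shards/data-{000..895}.tar")
    loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=8)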

dev/encoding/vqgan-jax-encoding-streaming.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dev/encoding/vqgan-jax-encoding-webdataset.ipynb ADDED
@@ -0,0 +1,408 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d0b72877",
+   "metadata": {},
+   "source": [
+    "# vqgan-jax-encoding-alamy"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ba7b31e6",
+   "metadata": {},
+   "source": [
+    "Encoding notebook for the Alamy dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "3b59489e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "import torch\n",
+    "import torchvision.transforms as T\n",
+    "import torchvision.transforms.functional as TF\n",
+    "from torchvision.transforms import InterpolationMode\n",
+    "import math\n",
+    "\n",
+    "import webdataset as wds\n",
+    "\n",
+    "import jax\n",
+    "from jax import pmap"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c7c4c1e6",
+   "metadata": {},
+   "source": [
+    "## Dataset and Parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "13c6631b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "shards = 'https://s3.us-west-1.wasabisys.com/doodlebot-wasabi/datasets/alamy/webdataset/alamy-{000..895}.tar'\n",
+    "\n",
+    "# Enable curl retries to try to work around temporary network / server errors.\n",
+    "# This shouldn't be necessary when using reliable servers.\n",
+    "shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'\n",
+    "\n",
+    "length = 44710810   # estimate\n",
+    "\n",
+    "from pathlib import Path\n",
+    "\n",
+    "# Output directory for encoded files\n",
+    "encoded_output = Path.home()/'data'/'alamy'/'encoded'\n",
+    "\n",
+    "batch_size = 128    # Per device\n",
+    "num_workers = 8     # Using larger numbers seemed to be less reliable in this case."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "3435fb85",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs = batch_size * jax.device_count()   # Use a smaller size for testing\n",
+    "batches = math.ceil(length / bs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "669b35df",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def center_crop(image, max_size=256):\n",
+    "    # Note: we allow upscaling too. We should exclude small images.\n",
+    "    image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
+    "    image = TF.center_crop(image, output_size=2 * [max_size])\n",
+    "    return image\n",
+    "\n",
+    "preprocess_image = T.Compose([\n",
+    "    center_crop,\n",
+    "    T.ToTensor(),\n",
+    "    lambda t: t.permute(1, 2, 0)   # Reorder, we need dimensions last\n",
+    "])\n",
+    "\n",
+    "# Is there a shortcut for this?\n",
+    "def extract_from_json(item):\n",
+    "    item['caption'] = item['json']['caption']\n",
+    "    item['url'] = item['json']['url']\n",
+    "    return item"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "369d9719",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Log exceptions to a hardcoded file\n",
+    "def ignore_and_log(exn):\n",
+    "    with open('errors.txt', 'a') as f:\n",
+    "        f.write(f'{exn}\\n')\n",
+    "    return True\n",
+    "\n",
+    "# Or simply use `wds.ignore_and_continue`\n",
+    "exception_handler = ignore_and_log\n",
+    "exception_handler = wds.warn_and_continue   # Overrides the line above; the last assignment wins"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "5149b6d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = wds.WebDataset(shards,\n",
+    "    length=batches,              # Hint so `len` is implemented\n",
+    "    shardshuffle=False,          # Keep same order for encoded files for easier bookkeeping\n",
+    "    handler=exception_handler,   # Ignore read errors instead of failing. See also: `warn_and_continue`\n",
+    ")\n",
+    "\n",
+    "dataset = (dataset\n",
+    "    .decode('pil')   # decode image with PIL\n",
+    "    .map(extract_from_json)\n",
+    "    .map_dict(jpg=preprocess_image, handler=exception_handler)\n",
+    "    .to_tuple('url', 'jpg', 'caption')   # filter to keep only url (for reference), image, caption.\n",
+    "    .batched(bs))   # Batch in the dataset (we could also do it in the dataloader); this does not affect speed, and we could remove it"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "8cac98cb",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 8min 26s, sys: 12.5 s, total: 8min 38s\n",
+      "Wall time: 14.4 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "urls, images, captions = next(iter(dataset))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "cd268fbf",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1024, 256, 256, 3])"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "images.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "44d50a51",
+   "metadata": {},
+   "source": [
+    "### Torch DataLoader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "e2df5e13",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a354472b",
+   "metadata": {},
+   "source": [
+    "## VQGAN-JAX model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "2fcf01d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from vqgan_jax.modeling_flax_vqgan import VQModel"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9daa636d",
+   "metadata": {},
+   "source": [
+    "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "47a8b818",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "62ad01c3",
+   "metadata": {},
+   "source": [
+    "## Encoding"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "20357f74",
+   "metadata": {},
+   "source": [
+    "Encoding is simple: `shard` automatically distributes \"superbatches\" across devices, and `pmap` runs the encoding on all devices in parallel. This is all it takes to create our encoding function, which will be jitted on first use."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "6686b004",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from flax.training.common_utils import shard\n",
+    "from functools import partial"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "322a4619",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@partial(jax.pmap, axis_name=\"batch\")\n",
+    "def encode(batch):\n",
+    "    # Not sure if we should `replicate` the params; it does not seem to have any effect\n",
+    "    _, indices = model.encode(batch)\n",
+    "    return indices"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "14375a41",
+   "metadata": {},
+   "source": [
+    "### Encoding loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "ff6c10d4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import pandas as pd\n",
+    "\n",
+    "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
+    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
+    "\n",
+    "    # Saving strategy:\n",
+    "    # - Create a new file every so often to prevent excessive file seeking.\n",
+    "    # - Save each batch after processing.\n",
+    "    # - Keep the file open until we are done with it.\n",
+    "    file = None\n",
+    "    for n, (urls, images, captions) in enumerate(tqdm(dataloader)):\n",
+    "        if (n % save_every == 0):\n",
+    "            if file is not None:\n",
+    "                file.close()\n",
+    "            split_num = n // save_every\n",
+    "            file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
+    "\n",
+    "        images = shard(images.numpy().squeeze())\n",
+    "        encoded = encode(images)\n",
+    "        encoded = encoded.reshape(-1, encoded.shape[-1])\n",
+    "\n",
+    "        encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
+    "        batch_df = pd.DataFrame.from_dict({\"url\": urls, \"caption\": captions, \"encoding\": encoded_as_string})\n",
+    "        batch_df.to_json(file, orient='records', lines=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "09ff75a3",
+   "metadata": {},
+   "source": [
+    "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "96222bb4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "save_every = 318"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7704863d",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  2%|█▌        | 1085/43663 [31:58<20:43:42, 1.75s/it]"
+     ]
+    }
+   ],
+   "source": [
+    "encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8953dd84",
+   "metadata": {},
+   "source": [
+    "----"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
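
The split sizing in the notebook is consistent with its saving strategy: 318 iterations × 1024 images × 256 tokens per image is roughly 83M tokens per split, or about 500 MB when written as comma-separated text. To read the resulting splits back, something like the following should work (a sketch: the directory and column names follow the notebook, while parsing the encoding strings with json.loads is an assumption based on how np.array2string writes them):

    import json
    from pathlib import Path
    import numpy as np
    import pandas as pd

    encoded_output = Path.home()/'data'/'alamy'/'encoded'  # same directory the notebook writes to

    # Each split is a jsonl file with url, caption, and encoding columns.
    df = pd.concat(
        [pd.read_json(f, orient='records', lines=True) for f in sorted(encoded_output.glob('split_*.jsonl'))],
        ignore_index=True,
    )

    # Encodings were saved with np.array2string(separator=','), e.g. "[12,407,...]",
    # which happens to be valid JSON, so json.loads recovers the token ids.
    encodings = np.array([json.loads(e) for e in df['encoding']])
    print(encodings.shape)   # (num_images, 256) for the f16 VQGAN on 256x256 crops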