Pedro Cuenca commited on
Commit
86ba774
1 Parent(s): 95d2faf

* Prepend [bos] to image encodings, rename to "labels".

Browse files
Files changed (1) hide show
  1. model/data-pipeline.ipynb +32 -13
model/data-pipeline.ipynb CHANGED
@@ -161,7 +161,8 @@
161
  "source": [
162
  "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
  "max_length = 256 # Read from data_args.max_source_length\n",
164
- "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')"
 
165
  ]
166
  },
167
  {
@@ -178,7 +179,7 @@
178
  " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
179
  " )\n",
180
  "\n",
181
- " model_inputs[\"eval_encoding\"] = [eval(indices) for indices in examples['encoding']]\n",
182
  "\n",
183
  " return model_inputs"
184
  ]
@@ -192,10 +193,10 @@
192
  "source": [
193
  "num_workers = 48 # We have 96 processors in the TPU\n",
194
  "column_names = dataset.column_names\n",
195
- "dataset = dataset.map(preprocess_function,\n",
196
- " remove_columns=column_names,\n",
197
- " batched=True,\n",
198
- " num_proc=48\n",
199
  ")"
200
  ]
201
  },
@@ -240,7 +241,7 @@
240
  "text": [
241
  "INFO:absl:Starting the local TPU driver.\n",
242
  "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
243
- "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n"
244
  ]
245
  }
246
  ],
@@ -257,7 +258,7 @@
257
  "metadata": {},
258
  "outputs": [],
259
  "source": [
260
- "loader = data_loader(rng, dataset, batch_size=super_batch_size)"
261
  ]
262
  },
263
  {
@@ -279,7 +280,7 @@
279
  {
280
  "data": {
281
  "text/plain": [
282
- "dict_keys(['attention_mask', 'eval_encoding', 'input_ids'])"
283
  ]
284
  },
285
  "execution_count": 13,
@@ -309,7 +310,7 @@
309
  }
310
  ],
311
  "source": [
312
- "len(superbatch[\"eval_encoding\"])"
313
  ]
314
  },
315
  {
@@ -321,7 +322,7 @@
321
  {
322
  "data": {
323
  "text/plain": [
324
- "(8, 64, 256)"
325
  ]
326
  },
327
  "execution_count": 15,
@@ -330,15 +331,33 @@
330
  }
331
  ],
332
  "source": [
333
- "superbatch[\"eval_encoding\"].shape"
 
 
 
 
 
 
 
 
334
  ]
335
  },
336
  {
337
  "cell_type": "code",
338
- "execution_count": null,
339
  "id": "cfe23a71",
340
  "metadata": {},
341
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
342
  "source": []
343
  }
344
  ],
 
161
  "source": [
162
  "# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
163
  "max_length = 256 # Read from data_args.max_source_length\n",
164
+ "tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
165
+ "image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
166
  ]
167
  },
168
  {
 
179
  " inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
180
  " )\n",
181
  "\n",
182
+ " model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
183
  "\n",
184
  " return model_inputs"
185
  ]
 
193
  "source": [
194
  "num_workers = 48 # We have 96 processors in the TPU\n",
195
  "column_names = dataset.column_names\n",
196
+ "input_dataset = dataset.map(preprocess_function,\n",
197
+ " remove_columns=column_names,\n",
198
+ " batched=True,\n",
199
+ " num_proc=48\n",
200
  ")"
201
  ]
202
  },
 
241
  "text": [
242
  "INFO:absl:Starting the local TPU driver.\n",
243
  "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
244
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
245
  ]
246
  }
247
  ],
 
258
  "metadata": {},
259
  "outputs": [],
260
  "source": [
261
+ "loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
262
  ]
263
  },
264
  {
 
280
  {
281
  "data": {
282
  "text/plain": [
283
+ "dict_keys(['attention_mask', 'input_ids', 'labels'])"
284
  ]
285
  },
286
  "execution_count": 13,
 
310
  }
311
  ],
312
  "source": [
313
+ "len(superbatch[\"labels\"])"
314
  ]
315
  },
316
  {
 
322
  {
323
  "data": {
324
  "text/plain": [
325
+ "(8, 64, 257)"
326
  ]
327
  },
328
  "execution_count": 15,
 
331
  }
332
  ],
333
  "source": [
334
+ "superbatch[\"labels\"].shape"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "markdown",
339
+ "id": "6800153b",
340
+ "metadata": {},
341
+ "source": [
342
+ "Any image sequence should begin with `image_bos`:"
343
  ]
344
  },
345
  {
346
  "cell_type": "code",
347
+ "execution_count": 16,
348
  "id": "cfe23a71",
349
  "metadata": {},
350
  "outputs": [],
351
+ "source": [
352
+ "assert superbatch[\"labels\"][1][5][0].item() == image_bos"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "0fb899b4",
359
+ "metadata": {},
360
+ "outputs": [],
361
  "source": []
362
  }
363
  ],