Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
86ba774
1
Parent(s):
95d2faf
* Prepend [bos] to image encodings, rename to "labels".
Browse files- model/data-pipeline.ipynb +32 -13
model/data-pipeline.ipynb
CHANGED
@@ -161,7 +161,8 @@
|
|
161 |
"source": [
|
162 |
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
163 |
"max_length = 256 # Read from data_args.max_source_length\n",
|
164 |
-
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')"
|
|
|
165 |
]
|
166 |
},
|
167 |
{
|
@@ -178,7 +179,7 @@
|
|
178 |
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
179 |
" )\n",
|
180 |
"\n",
|
181 |
-
" model_inputs[\"
|
182 |
"\n",
|
183 |
" return model_inputs"
|
184 |
]
|
@@ -192,10 +193,10 @@
|
|
192 |
"source": [
|
193 |
"num_workers = 48 # We have 96 processors in the TPU\n",
|
194 |
"column_names = dataset.column_names\n",
|
195 |
-
"
|
196 |
-
"
|
197 |
-
"
|
198 |
-
"
|
199 |
")"
|
200 |
]
|
201 |
},
|
@@ -240,7 +241,7 @@
|
|
240 |
"text": [
|
241 |
"INFO:absl:Starting the local TPU driver.\n",
|
242 |
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
243 |
-
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are:
|
244 |
]
|
245 |
}
|
246 |
],
|
@@ -257,7 +258,7 @@
|
|
257 |
"metadata": {},
|
258 |
"outputs": [],
|
259 |
"source": [
|
260 |
-
"loader = data_loader(rng,
|
261 |
]
|
262 |
},
|
263 |
{
|
@@ -279,7 +280,7 @@
|
|
279 |
{
|
280 |
"data": {
|
281 |
"text/plain": [
|
282 |
-
"dict_keys(['attention_mask', '
|
283 |
]
|
284 |
},
|
285 |
"execution_count": 13,
|
@@ -309,7 +310,7 @@
|
|
309 |
}
|
310 |
],
|
311 |
"source": [
|
312 |
-
"len(superbatch[\"
|
313 |
]
|
314 |
},
|
315 |
{
|
@@ -321,7 +322,7 @@
|
|
321 |
{
|
322 |
"data": {
|
323 |
"text/plain": [
|
324 |
-
"(8, 64,
|
325 |
]
|
326 |
},
|
327 |
"execution_count": 15,
|
@@ -330,15 +331,33 @@
|
|
330 |
}
|
331 |
],
|
332 |
"source": [
|
333 |
-
"superbatch[\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
]
|
335 |
},
|
336 |
{
|
337 |
"cell_type": "code",
|
338 |
-
"execution_count":
|
339 |
"id": "cfe23a71",
|
340 |
"metadata": {},
|
341 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
"source": []
|
343 |
}
|
344 |
],
|
|
|
161 |
"source": [
|
162 |
"# Setting padding=\"max_length\" as we need fixed length inputs for jitted functions\n",
|
163 |
"max_length = 256 # Read from data_args.max_source_length\n",
|
164 |
+
"tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')\n",
|
165 |
+
"image_bos = 16384 # Max token is 16383 in our VQGAN configuration"
|
166 |
]
|
167 |
},
|
168 |
{
|
|
|
179 |
" inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"np\"\n",
|
180 |
" )\n",
|
181 |
"\n",
|
182 |
+
" model_inputs[\"labels\"] = [[image_bos] + eval(indices) for indices in examples['encoding']]\n",
|
183 |
"\n",
|
184 |
" return model_inputs"
|
185 |
]
|
|
|
193 |
"source": [
|
194 |
"num_workers = 48 # We have 96 processors in the TPU\n",
|
195 |
"column_names = dataset.column_names\n",
|
196 |
+
"input_dataset = dataset.map(preprocess_function,\n",
|
197 |
+
" remove_columns=column_names,\n",
|
198 |
+
" batched=True,\n",
|
199 |
+
" num_proc=48\n",
|
200 |
")"
|
201 |
]
|
202 |
},
|
|
|
241 |
"text": [
|
242 |
"INFO:absl:Starting the local TPU driver.\n",
|
243 |
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
244 |
+
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Host TPU Interpreter\n"
|
245 |
]
|
246 |
}
|
247 |
],
|
|
|
258 |
"metadata": {},
|
259 |
"outputs": [],
|
260 |
"source": [
|
261 |
+
"loader = data_loader(rng, input_dataset, batch_size=super_batch_size)"
|
262 |
]
|
263 |
},
|
264 |
{
|
|
|
280 |
{
|
281 |
"data": {
|
282 |
"text/plain": [
|
283 |
+
"dict_keys(['attention_mask', 'input_ids', 'labels'])"
|
284 |
]
|
285 |
},
|
286 |
"execution_count": 13,
|
|
|
310 |
}
|
311 |
],
|
312 |
"source": [
|
313 |
+
"len(superbatch[\"labels\"])"
|
314 |
]
|
315 |
},
|
316 |
{
|
|
|
322 |
{
|
323 |
"data": {
|
324 |
"text/plain": [
|
325 |
+
"(8, 64, 257)"
|
326 |
]
|
327 |
},
|
328 |
"execution_count": 15,
|
|
|
331 |
}
|
332 |
],
|
333 |
"source": [
|
334 |
+
"superbatch[\"labels\"].shape"
|
335 |
+
]
|
336 |
+
},
|
337 |
+
{
|
338 |
+
"cell_type": "markdown",
|
339 |
+
"id": "6800153b",
|
340 |
+
"metadata": {},
|
341 |
+
"source": [
|
342 |
+
"Any image sequence should begin with `image_bos`:"
|
343 |
]
|
344 |
},
|
345 |
{
|
346 |
"cell_type": "code",
|
347 |
+
"execution_count": 16,
|
348 |
"id": "cfe23a71",
|
349 |
"metadata": {},
|
350 |
"outputs": [],
|
351 |
+
"source": [
|
352 |
+
"assert superbatch[\"labels\"][1][5][0].item() == image_bos"
|
353 |
+
]
|
354 |
+
},
|
355 |
+
{
|
356 |
+
"cell_type": "code",
|
357 |
+
"execution_count": null,
|
358 |
+
"id": "0fb899b4",
|
359 |
+
"metadata": {},
|
360 |
+
"outputs": [],
|
361 |
"source": []
|
362 |
}
|
363 |
],
|